When updating a policy's 'platform' field, the aggregated policy stats are now cleared. (#18415)

#18157
When updating a policy's 'platform' field, the aggregated policy stats
are now cleared.

# Checklist for submitter
- [x] Changes file added for user-visible changes in `changes/`,
`orbit/changes/` or `ee/fleetd-chrome/changes`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- [x] Added/updated tests
- [x] Manual QA for all new/changed functionality
This commit is contained in:
Victor Lyuboslavsky 2024-04-29 10:20:59 -05:00 committed by GitHub
parent 0b66bc4a9b
commit d0f0d3d017
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 527 additions and 58 deletions

View file

@ -0,0 +1 @@
When updating a policy's 'platform' field, the aggregated policy stats are now cleared.

View file

@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"sort"
"strings"
@ -14,6 +15,7 @@ import (
"github.com/doug-martin/goqu/v9"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
kitlog "github.com/go-kit/kit/log"
"github.com/go-kit/kit/log/level"
"github.com/jmoiron/sqlx"
)
@ -110,8 +112,8 @@ func policyDB(ctx context.Context, q sqlx.QueryerContext, id uint, teamID *uint)
// SavePolicy updates some fields of the given policy on the datastore.
//
// Currently SavePolicy does not allow updating the team of an existing policy.
func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error {
// Currently, SavePolicy does not allow updating the team of an existing policy.
func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error {
// We must normalize the name for full Unicode support (Unicode equivalence).
p.Name = norm.NFC.String(p.Name)
sql := `
@ -133,10 +135,39 @@ func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemo
return ctxerr.Wrap(ctx, notFound("Policy").WithID(p.ID))
}
return cleanupPolicy(ctx, ds.writer(ctx), p.ID, p.Platform, shouldRemoveAllPolicyMemberships, removePolicyStats, ds.logger)
}
func cleanupPolicy(
ctx context.Context, extContext sqlx.ExtContext, policyID uint, policyPlatform string, shouldRemoveAllPolicyMemberships bool,
removePolicyStats bool, logger kitlog.Logger,
) error {
var err error
if shouldRemoveAllPolicyMemberships {
return ds.cleanupPolicyMembershipForPolicy(ctx, p.ID)
err = cleanupPolicyMembershipForPolicy(ctx, extContext, policyID)
} else {
err = cleanupPolicyMembershipOnPolicyUpdate(ctx, extContext, policyID, policyPlatform)
}
return cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.writer(ctx), p.ID, p.Platform)
if err != nil {
return err
}
if removePolicyStats {
// delete all policy stats for the policy
fn := func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, `DELETE FROM policy_stats WHERE policy_id = ?`, policyID)
return err
}
if _, isDB := extContext.(*sqlx.DB); isDB {
// wrapping in a retry to avoid deadlocks with the cleanups_then_aggregation cron job
err = withRetryTxx(ctx, extContext.(*sqlx.DB), fn, logger)
} else {
err = fn(extContext)
}
if err != nil {
return ctxerr.Wrap(ctx, err, "cleanup policy stats")
}
}
return nil
}
// FlippingPoliciesForHost fetches previous policy membership results and returns:
@ -576,8 +607,74 @@ func (ds *Datastore) TeamPolicy(ctx context.Context, teamID uint, policyID uint)
// NOTE: Similar to ApplyQueries, ApplyPolicySpecs will update the author_id of the policies
// that are updated.
//
// Currently ApplyPolicySpecs does not allow updating the team of an existing policy.
// Currently, ApplyPolicySpecs does not allow updating the team of an existing policy.
func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs []*fleet.PolicySpec) error {
// Use the same DB for all operations in this method for performance
queryerContext := ds.writer(ctx)
// Preprocess specs and group them by team
teamNameToID := make(map[string]uint, 1)
teamIDToPolicies := make(map[uint][]*fleet.PolicySpec, 1)
// Get the team IDs
for _, spec := range specs {
// We must normalize the name for full Unicode support (Unicode equivalence).
spec.Name = norm.NFC.String(spec.Name)
spec.Team = norm.NFC.String(spec.Team)
teamID, ok := teamNameToID[spec.Team]
if !ok {
if spec.Team != "" {
// if team name is not empty, it must have a team ID; otherwise teamID defaults to 0 value
err := sqlx.GetContext(ctx, queryerContext, &teamID, `SELECT id FROM teams WHERE name = ?`, spec.Team)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ctxerr.Wrap(ctx, notFound("Team").WithName(spec.Team), "get team id")
}
return ctxerr.Wrap(ctx, err, "get team id")
}
}
teamNameToID[spec.Team] = teamID
}
teamIDToPolicies[teamID] = append(teamIDToPolicies[teamID], spec)
}
// Get the query and platforms of the current policies so that we can check if query or platform changed later, if needed
type policyLite struct {
Name string `db:"name"`
Query string `db:"query"`
Platforms string `db:"platforms"`
}
teamIDToPoliciesByName := make(map[uint]map[string]policyLite, len(teamIDToPolicies))
for teamID, teamPolicySpecs := range teamIDToPolicies {
teamIDToPoliciesByName[teamID] = make(map[string]policyLite, len(teamPolicySpecs))
policyNames := make([]string, 0, len(teamPolicySpecs))
for _, spec := range teamPolicySpecs {
policyNames = append(policyNames, spec.Name)
}
var query string
var args []interface{}
var err error
if teamID == 0 {
query, args, err = sqlx.In("SELECT name, query, platforms FROM policies WHERE team_id IS NULL AND name IN (?)", policyNames)
} else {
query, args, err = sqlx.In(
"SELECT name, query, platforms FROM policies WHERE team_id = ? AND name IN (?)", &teamID, policyNames,
)
}
if err != nil {
return ctxerr.Wrap(ctx, err, "building query to get policies by name")
}
policies := make([]policyLite, 0, len(teamPolicySpecs))
err = sqlx.SelectContext(ctx, queryerContext, &policies, query, args...)
if err != nil {
return ctxerr.Wrap(ctx, err, "getting policies by name")
}
for _, p := range policies {
teamIDToPoliciesByName[teamID][p.Name] = p
}
}
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
query := fmt.Sprintf(
`
@ -592,7 +689,7 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
critical,
calendar_events_enabled,
checksum
) VALUES ( ?, ?, ?, ?, ?, (SELECT IFNULL(MIN(id), NULL) FROM teams WHERE name = ?), ?, ?, ?, %s)
) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, %s)
ON DUPLICATE KEY UPDATE
query = VALUES(query),
description = VALUES(description),
@ -603,24 +700,45 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
calendar_events_enabled = VALUES(calendar_events_enabled)
`, policiesChecksumComputedColumn(),
)
for _, spec := range specs {
// We must normalize the name for full Unicode support (Unicode equivalence).
spec.Name = norm.NFC.String(spec.Name)
res, err := tx.ExecContext(ctx,
query, spec.Name, spec.Query, spec.Description, authorID, spec.Resolution, spec.Team, spec.Platform, spec.Critical,
spec.CalendarEventsEnabled,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "exec ApplyPolicySpecs insert")
for teamID, teamPolicySpecs := range teamIDToPolicies {
var teamIDPtr *uint
if teamID != 0 {
teamIDPtr = &teamID
}
for _, spec := range teamPolicySpecs {
if insertOnDuplicateDidUpdate(res) {
// when the upsert results in an UPDATE that *did* change some values,
// it returns the updated ID as last inserted id.
if lastID, _ := res.LastInsertId(); lastID > 0 {
if err := cleanupPolicyMembershipOnPolicyUpdate(ctx, tx, uint(lastID), spec.Platform); err != nil {
return err
res, err := tx.ExecContext(
ctx,
query, spec.Name, spec.Query, spec.Description, authorID, spec.Resolution, teamIDPtr, spec.Platform, spec.Critical,
spec.CalendarEventsEnabled,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "exec ApplyPolicySpecs insert")
}
if insertOnDuplicateDidUpdate(res) {
// when the upsert results in an UPDATE that *did* change some values,
// it returns the updated ID as last inserted id.
if lastID, _ := res.LastInsertId(); lastID > 0 {
var (
shouldRemoveAllPolicyMemberships bool
removePolicyStats bool
)
// Figure out if the query or platform changed
if prev, ok := teamIDToPoliciesByName[teamID][spec.Name]; ok {
switch {
case prev.Query != spec.Query:
shouldRemoveAllPolicyMemberships = true
removePolicyStats = true
case prev.Platforms != spec.Platform:
removePolicyStats = true
}
}
if err = cleanupPolicy(
ctx, tx, uint(lastID), spec.Platform, shouldRemoveAllPolicyMemberships, removePolicyStats, ds.logger,
); err != nil {
return err
}
}
}
}
@ -739,7 +857,7 @@ func cleanupPolicyMembershipOnPolicyUpdate(ctx context.Context, db sqlx.ExecerCo
// cleanupPolicyMembership is similar to cleanupPolicyMembershipOnPolicyUpdate but without the platform constraints.
// Used when we want to remove all policy membership.
func (ds *Datastore) cleanupPolicyMembershipForPolicy(ctx context.Context, policyID uint) error {
func cleanupPolicyMembershipForPolicy(ctx context.Context, exec sqlx.ExecerContext, policyID uint) error {
// delete all policy memberships for the policy
delStmt := `
DELETE
@ -754,21 +872,11 @@ func (ds *Datastore) cleanupPolicyMembershipForPolicy(ctx context.Context, polic
pm.policy_id = ?
`
_, err := ds.writer(ctx).ExecContext(ctx, delStmt, policyID)
_, err := exec.ExecContext(ctx, delStmt, policyID)
if err != nil {
return ctxerr.Wrap(ctx, err, "cleanup policy membership")
}
// delete all policy stats for the policy
// wrapping in a retry to avoid deadlocks with the cleanups_then_aggregation cron job
err = ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, `DELETE FROM policy_stats WHERE policy_id = ?`, policyID)
return err
})
if err != nil {
return ctxerr.Wrap(ctx, err, "cleanup policy stats")
}
return nil
}

View file

@ -40,6 +40,7 @@ func TestPolicies(t *testing.T) {
{"PoliciesByID", testPoliciesByID},
{"TeamPolicyTransfer", testTeamPolicyTransfer},
{"ApplyPolicySpec", testApplyPolicySpec},
{"ApplyPolicySpecWithQueryPlatformChanges", testApplyPolicySpecWithQueryPlatformChanges},
{"Save", testPoliciesSave},
{"DelUser", testPoliciesDelUser},
{"FlippingPoliciesForHost", testFlippingPoliciesForHost},
@ -1400,6 +1401,298 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) {
}))
}
func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) {
ctx := context.Background()
unicode, _ := strconv.Unquote(`"\uAC00"`) // 가
unicodeEq, _ := strconv.Unquote(`"\u1100\u1161"`) // ᄀ + ᅡ
user1 := test.NewUser(t, ds, "User1", "user1@example.com", true)
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1" + unicode})
require.NoError(t, err)
globalNames := []string{"global query1" + unicode, "global query2" + unicode, "global query3" + unicode}
teamNames := []string{"team query1", "team query2", "team query3"}
require.NoError(
t, ds.ApplyPolicySpecs(
ctx, user1.ID, []*fleet.PolicySpec{
{
Name: globalNames[0],
Query: "select 1;",
Team: "",
Platform: "",
},
{
Name: globalNames[1],
Query: "select 2;",
Team: "",
Platform: "darwin",
},
{
Name: globalNames[2],
Query: "select 3;",
Team: "",
Platform: "darwin,linux",
},
{
Name: teamNames[0],
Query: "select 1;",
Team: "team1" + unicode,
Platform: "",
},
{
Name: teamNames[1],
Query: "select 2;",
Team: "team1" + unicode,
Platform: "darwin",
},
{
Name: teamNames[2],
Query: "select 3;",
Team: "team1" + unicodeEq,
Platform: "darwin,linux",
},
},
),
)
// create hosts with different platforms, for that team
const hostWin, hostMac, hostDeb, hostLin = 0, 1, 2, 3
platforms := []string{"windows", "darwin", "debian", "linux"}
teamHosts := make([]*fleet.Host, len(platforms))
for i, pl := range platforms {
id := fmt.Sprintf("%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i)
h, err := ds.NewHost(
ctx, &fleet.Host{
OsqueryHostID: &id,
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: &id,
UUID: id,
Hostname: id,
Platform: pl,
TeamID: ptr.Uint(team1.ID),
},
)
require.NoError(t, err)
teamHosts[i] = h
}
// create hosts with different platforms, without team
globalHosts := make([]*fleet.Host, len(platforms))
for i, pl := range platforms {
id := fmt.Sprintf("g%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i)
h, err := ds.NewHost(
ctx, &fleet.Host{
OsqueryHostID: &id,
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: &id,
UUID: id,
Hostname: id,
Platform: pl,
},
)
require.NoError(t, err)
globalHosts[i] = h
}
// load the global policies
gPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, gPolicies, 3)
// load the team policies
tPolicies, _, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, tPolicies, 3)
// index the policies by name for easier access in the rest of the test
polsByName := make(map[string]*fleet.Policy, len(gPolicies)+len(tPolicies))
globalPolsByName := make(map[string]*fleet.Policy, len(gPolicies))
for _, pol := range tPolicies {
polsByName[pol.Name] = pol
}
for _, pol := range gPolicies {
globalPolsByName[pol.Name] = pol
polsByName[pol.Name] = pol
}
// record some results for each policy
// Note: we are adding results to hosts that shouldn't have results, based on their platform.
for _, h := range teamHosts {
res := make(map[uint]*bool, len(polsByName))
for _, pol := range polsByName {
res[pol.ID] = ptr.Bool(false)
}
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
require.NoError(t, err)
}
for _, h := range globalHosts {
res := make(map[uint]*bool, len(globalPolsByName))
for _, pol := range globalPolsByName {
res[pol.ID] = ptr.Bool(false)
}
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
require.NoError(t, err)
}
err = ds.UpdateHostPolicyCounts(ctx)
require.NoError(t, err)
// Update host failure counts and ensure they are correct
teamHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, teamHosts)
require.NoError(t, err)
assert.Equal(t, 6, teamHosts[hostWin].FailingPoliciesCount)
assert.Equal(t, 6, teamHosts[hostMac].FailingPoliciesCount)
assert.Equal(t, 6, teamHosts[hostDeb].FailingPoliciesCount)
assert.Equal(t, 6, teamHosts[hostLin].FailingPoliciesCount)
globalHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, globalHosts)
require.NoError(t, err)
assert.Equal(t, 3, globalHosts[hostWin].FailingPoliciesCount)
assert.Equal(t, 3, globalHosts[hostMac].FailingPoliciesCount)
assert.Equal(t, 3, globalHosts[hostDeb].FailingPoliciesCount)
assert.Equal(t, 3, globalHosts[hostLin].FailingPoliciesCount)
// Ensure policy passing and failing counts are correct
gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, gPolicies, 3)
tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, tPolicies, 3)
for _, pol := range gPolicies {
polsByName[pol.Name] = pol
}
for _, pol := range tPolicies {
polsByName[pol.Name] = pol
}
assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount)
assert.Equal(t, uint(8), polsByName[globalNames[1]].FailingHostCount)
assert.Equal(t, uint(8), polsByName[globalNames[2]].FailingHostCount)
assert.Equal(t, uint(4), polsByName[teamNames[0]].FailingHostCount)
assert.Equal(t, uint(4), polsByName[teamNames[1]].FailingHostCount)
assert.Equal(t, uint(4), polsByName[teamNames[2]].FailingHostCount)
// Update policies
require.NoError(
t, ds.ApplyPolicySpecs(
ctx, user1.ID, []*fleet.PolicySpec{
{
Name: globalNames[0],
Query: "select 1;",
Team: "",
Platform: "",
Description: "updated", // update description
},
{
Name: globalNames[1],
Query: "select 2 updated;", // update query
Team: "",
Platform: "darwin",
},
{
Name: globalNames[2],
Query: "select 3;",
Team: "",
Platform: "darwin", // update platform
},
{
Name: "new global query",
Query: "select 4;",
Team: "",
Platform: "",
},
{
Name: teamNames[0],
Query: "select 1;",
Team: "team1" + unicode,
Platform: "linux", // update platform
},
{
Name: teamNames[1],
Query: "select 2;",
Team: "team1" + unicode,
Platform: "darwin",
CalendarEventsEnabled: true, // update calendar events
},
{
Name: teamNames[2],
Query: "select 3 updated;", // update query
Team: "team1" + unicodeEq,
Platform: "darwin,linux",
},
{
Name: "new team query",
Query: "select 4;",
Team: "team1" + unicode,
Platform: "",
},
},
),
)
// Update host failure counts and ensure they are correct
teamHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, teamHosts)
require.NoError(t, err)
assert.Equal(t, 1, teamHosts[hostWin].FailingPoliciesCount) // kept result from globalNames[0]
assert.Equal(t, 3, teamHosts[hostMac].FailingPoliciesCount)
assert.Equal(t, 2, teamHosts[hostDeb].FailingPoliciesCount)
assert.Equal(t, 2, teamHosts[hostLin].FailingPoliciesCount)
globalHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, globalHosts)
require.NoError(t, err)
assert.Equal(t, 1, globalHosts[hostWin].FailingPoliciesCount)
assert.Equal(t, 2, globalHosts[hostMac].FailingPoliciesCount)
assert.Equal(t, 1, globalHosts[hostDeb].FailingPoliciesCount)
assert.Equal(t, 1, globalHosts[hostLin].FailingPoliciesCount)
// Ensure policy passing and failing counts are correct
gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, gPolicies, 4)
tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, tPolicies, 4)
for _, pol := range gPolicies {
polsByName[pol.Name] = pol
}
for _, pol := range tPolicies {
polsByName[pol.Name] = pol
}
assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount)
assert.Equal(t, uint(0), polsByName[globalNames[1]].FailingHostCount) // updated query
assert.Equal(t, uint(0), polsByName[globalNames[2]].FailingHostCount) // updated platform
assert.Equal(t, uint(0), polsByName[teamNames[0]].FailingHostCount) // updated platform
assert.Equal(t, uint(4), polsByName[teamNames[1]].FailingHostCount)
assert.Equal(t, uint(0), polsByName[teamNames[2]].FailingHostCount) // updated query
err = ds.UpdateHostPolicyCounts(ctx)
require.NoError(t, err)
gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, gPolicies, 4)
tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{})
require.NoError(t, err)
require.Len(t, tPolicies, 4)
for _, pol := range gPolicies {
polsByName[pol.Name] = pol
}
for _, pol := range tPolicies {
polsByName[pol.Name] = pol
}
assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount) // platform is "" -- no change
assert.Equal(t, uint(0), polsByName[globalNames[1]].FailingHostCount) // updated query
assert.Equal(t, uint(2), polsByName[globalNames[2]].FailingHostCount) // updated platform
assert.Equal(t, uint(2), polsByName[teamNames[0]].FailingHostCount) // updated platform
assert.Equal(t, uint(1), polsByName[teamNames[1]].FailingHostCount) // platform is "darwin" -- no change
assert.Equal(t, uint(0), polsByName[teamNames[2]].FailingHostCount) // updated query
}
func testPoliciesSave(t *testing.T, ds *Datastore) {
user1 := test.NewUser(t, ds, "User1", "user1@example.com", true)
ctx := context.Background()
@ -1412,7 +1705,8 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
Name: "non-existent query",
Query: "select 1;",
},
}, false)
}, false, false,
)
require.Error(t, err)
var nfe *notFoundError
require.True(t, errors.As(err, &nfe))
@ -1473,7 +1767,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
gp2 := *gp
gp2.Name = "global query updated"
gp2.Critical = true
err = ds.SavePolicy(ctx, &gp2, false)
err = ds.SavePolicy(ctx, &gp2, false, false)
require.NoError(t, err)
gp, err = ds.Policy(ctx, gp.ID)
require.NoError(t, err)
@ -1493,7 +1787,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
tp2.Resolution = ptr.String("team1 query resolution updated")
tp2.Critical = false
tp2.CalendarEventsEnabled = false
err = ds.SavePolicy(ctx, &tp2, true)
err = ds.SavePolicy(ctx, &tp2, true, true)
require.NoError(t, err)
tp1, err = ds.Policy(ctx, tp1.ID)
tp2.UpdateCreateTimestamps = tp1.UpdateCreateTimestamps
@ -1586,7 +1880,7 @@ func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) {
assert.Equal(t, uint(1), inheritedPolicies[0].PassingHostCount)
// Update the global policy sql to trigger a cache invalidation
err = ds.SavePolicy(ctx, globalPolicy, true)
err = ds.SavePolicy(ctx, globalPolicy, true, true)
require.NoError(t, err)
globalPolicy, err = ds.Policy(ctx, globalPolicy.ID)
@ -1599,8 +1893,8 @@ func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) {
assert.Equal(t, uint(1), teamPolicies[0].PassingHostCount)
assert.Equal(t, uint(0), inheritedPolicies[0].PassingHostCount)
// Update the team policy sql to trigger a cache invalidation
err = ds.SavePolicy(ctx, teamPolicy, true)
// Update the team policy platform to trigger a cache invalidation
err = ds.SavePolicy(ctx, teamPolicy, false, true)
require.NoError(t, err)
teamPolicies, inheritedPolicies, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{})
@ -1921,9 +2215,9 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
}
// updating without change works fine
err = ds.SavePolicy(ctx, polsByName["g1"], false)
err = ds.SavePolicy(ctx, polsByName["g1"], false, false)
require.NoError(t, err)
err = ds.SavePolicy(ctx, polsByName["t2"], false)
err = ds.SavePolicy(ctx, polsByName["t2"], false, false)
require.NoError(t, err)
// apply specs that result in an update (without change) works fine
err = ds.ApplyPolicySpecs(ctx, user.ID, []*fleet.PolicySpec{
@ -1975,7 +2269,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
g1 := polsByName["g1"]
g1.Platform = "linux"
polsByName["g1"] = g1
err = ds.SavePolicy(ctx, g1, false)
err = ds.SavePolicy(ctx, g1, false, false)
require.NoError(t, err)
wantHostsByPol["g1"] = []uint{globalHosts[hostDeb].ID, globalHosts[hostLin].ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
@ -1984,7 +2278,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
t1 := polsByName["t1"]
t1.Platform = "windows,darwin"
polsByName["t1"] = t1
err = ds.SavePolicy(ctx, t1, false)
err = ds.SavePolicy(ctx, t1, false, false)
require.NoError(t, err)
wantHostsByPol["t1"] = []uint{teamHosts[hostWin].ID, teamHosts[hostMac].ID}
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
@ -2723,7 +3017,7 @@ func testPoliciesNameUnicode(t *testing.T, ds *Datastore) {
policyEmoji, err := ds.NewGlobalPolicy(context.Background(), nil, fleet.PolicyPayload{Name: "💻"})
require.NoError(t, err)
err = ds.SavePolicy(
context.Background(), &fleet.Policy{PolicyData: fleet.PolicyData{ID: policyEmoji.ID, Name: equivalentNames[1]}}, false,
context.Background(), &fleet.Policy{PolicyData: fleet.PolicyData{ID: policyEmoji.ID, Name: equivalentNames[1]}}, false, false,
)
assert.True(t, isDuplicate(err), err)
@ -3270,10 +3564,10 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
//
team2Policy1.Platform = "darwin"
err = ds.SavePolicy(ctx, team1Policy1, false)
err = ds.SavePolicy(ctx, team1Policy1, false, true)
require.NoError(t, err)
team1Policy1.Platform = "darwin"
err = ds.SavePolicy(ctx, team2Policy1, false)
err = ds.SavePolicy(ctx, team2Policy1, false, true)
require.NoError(t, err)
//

View file

@ -602,7 +602,7 @@ type Datastore interface {
// SavePolicy updates some fields of the given policy on the datastore.
//
// It is also used to update team policies.
SavePolicy(ctx context.Context, p *Policy, shouldRemoveAllPolicyMemberships bool) error
SavePolicy(ctx context.Context, p *Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error
ListGlobalPolicies(ctx context.Context, opts ListOptions) ([]*Policy, error)
PoliciesByID(ctx context.Context, ids []uint) (map[uint]*Policy, error)

View file

@ -435,7 +435,7 @@ type NewGlobalPolicyFunc func(ctx context.Context, authorID *uint, args fleet.Po
type PolicyFunc func(ctx context.Context, id uint) (*fleet.Policy, error)
type SavePolicyFunc func(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error
type SavePolicyFunc func(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error
type ListGlobalPoliciesFunc func(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error)
@ -3729,11 +3729,11 @@ func (s *DataStore) Policy(ctx context.Context, id uint) (*fleet.Policy, error)
return s.PolicyFunc(ctx, id)
}
func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error {
func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error {
s.mu.Lock()
s.SavePolicyFuncInvoked = true
s.mu.Unlock()
return s.SavePolicyFunc(ctx, p, shouldRemoveAllPolicyMemberships)
return s.SavePolicyFunc(ctx, p, shouldRemoveAllPolicyMemberships, removePolicyStats)
}
func (s *DataStore) ListGlobalPolicies(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) {

View file

@ -68,7 +68,7 @@ func TestGlobalPoliciesAuth(t *testing.T) {
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error {
return nil
}
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error {
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool, removePolicyStats bool) error {
return nil
}
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {

View file

@ -2428,6 +2428,65 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount)
assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount)
// Record query executions
require.NoError(
t, s.ds.RecordPolicyQueryExecutions(
context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false,
),
)
require.NoError(
t, s.ds.RecordPolicyQueryExecutions(
context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false,
),
)
// Update policy stats
require.NoError(t, s.ds.UpdateHostPolicyCounts(context.Background()))
// Fetch policy to make sure stats are updated
s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse)
require.Len(t, policiesResponse.Policies, 1)
assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount)
assert.Equal(t, uint(1), policiesResponse.Policies[0].PassingHostCount)
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
listHostsResp = listHostsResponse{}
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 1)
// Modify the platform for the policy, which should clear the policy stats
mgpParams = modifyGlobalPolicyRequest{
ModifyPolicyPayload: fleet.ModifyPolicyPayload{
Platform: ptr.String("linux"),
},
}
mgpResp = modifyGlobalPolicyResponse{}
s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/policies/%d", gpResp.Policy.ID), mgpParams, http.StatusOK, &mgpResp)
require.NotNil(t, gpResp.Policy)
assert.Equal(t, "TestQuery4", mgpResp.Policy.Name)
assert.Equal(t, "select * from users;", mgpResp.Policy.Query)
assert.Equal(t, "Some description updated", mgpResp.Policy.Description)
require.NotNil(t, mgpResp.Policy.Resolution)
assert.Equal(t, "some global resolution updated", *mgpResp.Policy.Resolution)
assert.Equal(t, "linux", mgpResp.Policy.Platform)
assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount)
assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount)
// Fetch policy to make sure stats are updated
s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse)
require.Len(t, policiesResponse.Policies, 1)
assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount)
assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount)
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
listHostsResp = listHostsResponse{}
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 0)
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=failing", policiesResponse.Policies[0].ID)
listHostsResp = listHostsResponse{}
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 0)
deletePolicyParams := deleteGlobalPoliciesRequest{IDs: []uint{policiesResponse.Policies[0].ID}}
deletePolicyResp := deleteGlobalPoliciesResponse{}
s.DoJSON("POST", "/api/latest/fleet/policies/delete", deletePolicyParams, http.StatusOK, &deletePolicyResp)

View file

@ -368,7 +368,8 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
})
}
var shouldRemoveAll bool
var removeAllMemberships bool
var removeStats bool
if p.Name != nil {
policy.Name = *p.Name
}
@ -377,9 +378,8 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
}
if p.Query != nil {
if policy.Query != *p.Query {
shouldRemoveAll = true
policy.FailingHostCount = 0
policy.PassingHostCount = 0
removeAllMemberships = true
removeStats = true
}
policy.Query = *p.Query
}
@ -387,6 +387,9 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
policy.Resolution = p.Resolution
}
if p.Platform != nil {
if policy.Platform != *p.Platform {
removeStats = true
}
policy.Platform = *p.Platform
}
if p.Critical != nil {
@ -395,9 +398,13 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
if p.CalendarEventsEnabled != nil {
policy.CalendarEventsEnabled = *p.CalendarEventsEnabled
}
if removeStats {
policy.FailingHostCount = 0
policy.PassingHostCount = 0
}
logging.WithExtras(ctx, "name", policy.Name, "sql", policy.Query)
err = svc.ds.SavePolicy(ctx, policy, shouldRemoveAll)
err = svc.ds.SavePolicy(ctx, policy, removeAllMemberships, removeStats)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "saving policy")
}

View file

@ -44,7 +44,7 @@ func TestTeamPoliciesAuth(t *testing.T) {
}
return nil, nil
}
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error {
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool, removePolicyStats bool) error {
return nil
}
ds.DeleteTeamPoliciesFunc = func(ctx context.Context, teamID uint, ids []uint) ([]uint, error) {