mirror of
https://github.com/fleetdm/fleet
synced 2026-05-24 09:28:54 +00:00
When updating a policy's 'platform' field, the aggregated policy stats are now cleared. (#18415)
#18157 When updating a policy's 'platform' field, the aggregated policy stats are now cleared. # Checklist for submitter - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [x] Added/updated tests - [x] Manual QA for all new/changed functionality
This commit is contained in:
parent
0b66bc4a9b
commit
d0f0d3d017
9 changed files with 527 additions and 58 deletions
1
changes/18157-update-platform-policy-stats
Normal file
1
changes/18157-update-platform-policy-stats
Normal file
|
|
@ -0,0 +1 @@
|
|||
When updating a policy's 'platform' field, the aggregated policy stats are now cleared.
|
||||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
|
@ -14,6 +15,7 @@ import (
|
|||
"github.com/doug-martin/goqu/v9"
|
||||
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
||||
"github.com/fleetdm/fleet/v4/server/fleet"
|
||||
kitlog "github.com/go-kit/kit/log"
|
||||
"github.com/go-kit/kit/log/level"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
|
@ -110,8 +112,8 @@ func policyDB(ctx context.Context, q sqlx.QueryerContext, id uint, teamID *uint)
|
|||
|
||||
// SavePolicy updates some fields of the given policy on the datastore.
|
||||
//
|
||||
// Currently SavePolicy does not allow updating the team of an existing policy.
|
||||
func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error {
|
||||
// Currently, SavePolicy does not allow updating the team of an existing policy.
|
||||
func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error {
|
||||
// We must normalize the name for full Unicode support (Unicode equivalence).
|
||||
p.Name = norm.NFC.String(p.Name)
|
||||
sql := `
|
||||
|
|
@ -133,10 +135,39 @@ func (ds *Datastore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemo
|
|||
return ctxerr.Wrap(ctx, notFound("Policy").WithID(p.ID))
|
||||
}
|
||||
|
||||
return cleanupPolicy(ctx, ds.writer(ctx), p.ID, p.Platform, shouldRemoveAllPolicyMemberships, removePolicyStats, ds.logger)
|
||||
}
|
||||
|
||||
func cleanupPolicy(
|
||||
ctx context.Context, extContext sqlx.ExtContext, policyID uint, policyPlatform string, shouldRemoveAllPolicyMemberships bool,
|
||||
removePolicyStats bool, logger kitlog.Logger,
|
||||
) error {
|
||||
var err error
|
||||
if shouldRemoveAllPolicyMemberships {
|
||||
return ds.cleanupPolicyMembershipForPolicy(ctx, p.ID)
|
||||
err = cleanupPolicyMembershipForPolicy(ctx, extContext, policyID)
|
||||
} else {
|
||||
err = cleanupPolicyMembershipOnPolicyUpdate(ctx, extContext, policyID, policyPlatform)
|
||||
}
|
||||
return cleanupPolicyMembershipOnPolicyUpdate(ctx, ds.writer(ctx), p.ID, p.Platform)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if removePolicyStats {
|
||||
// delete all policy stats for the policy
|
||||
fn := func(tx sqlx.ExtContext) error {
|
||||
_, err := tx.ExecContext(ctx, `DELETE FROM policy_stats WHERE policy_id = ?`, policyID)
|
||||
return err
|
||||
}
|
||||
if _, isDB := extContext.(*sqlx.DB); isDB {
|
||||
// wrapping in a retry to avoid deadlocks with the cleanups_then_aggregation cron job
|
||||
err = withRetryTxx(ctx, extContext.(*sqlx.DB), fn, logger)
|
||||
} else {
|
||||
err = fn(extContext)
|
||||
}
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "cleanup policy stats")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// FlippingPoliciesForHost fetches previous policy membership results and returns:
|
||||
|
|
@ -576,8 +607,74 @@ func (ds *Datastore) TeamPolicy(ctx context.Context, teamID uint, policyID uint)
|
|||
// NOTE: Similar to ApplyQueries, ApplyPolicySpecs will update the author_id of the policies
|
||||
// that are updated.
|
||||
//
|
||||
// Currently ApplyPolicySpecs does not allow updating the team of an existing policy.
|
||||
// Currently, ApplyPolicySpecs does not allow updating the team of an existing policy.
|
||||
func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs []*fleet.PolicySpec) error {
|
||||
// Use the same DB for all operations in this method for performance
|
||||
queryerContext := ds.writer(ctx)
|
||||
|
||||
// Preprocess specs and group them by team
|
||||
teamNameToID := make(map[string]uint, 1)
|
||||
teamIDToPolicies := make(map[uint][]*fleet.PolicySpec, 1)
|
||||
|
||||
// Get the team IDs
|
||||
for _, spec := range specs {
|
||||
// We must normalize the name for full Unicode support (Unicode equivalence).
|
||||
spec.Name = norm.NFC.String(spec.Name)
|
||||
spec.Team = norm.NFC.String(spec.Team)
|
||||
teamID, ok := teamNameToID[spec.Team]
|
||||
if !ok {
|
||||
if spec.Team != "" {
|
||||
// if team name is not empty, it must have a team ID; otherwise teamID defaults to 0 value
|
||||
err := sqlx.GetContext(ctx, queryerContext, &teamID, `SELECT id FROM teams WHERE name = ?`, spec.Team)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return ctxerr.Wrap(ctx, notFound("Team").WithName(spec.Team), "get team id")
|
||||
}
|
||||
return ctxerr.Wrap(ctx, err, "get team id")
|
||||
}
|
||||
}
|
||||
teamNameToID[spec.Team] = teamID
|
||||
}
|
||||
teamIDToPolicies[teamID] = append(teamIDToPolicies[teamID], spec)
|
||||
}
|
||||
|
||||
// Get the query and platforms of the current policies so that we can check if query or platform changed later, if needed
|
||||
type policyLite struct {
|
||||
Name string `db:"name"`
|
||||
Query string `db:"query"`
|
||||
Platforms string `db:"platforms"`
|
||||
}
|
||||
teamIDToPoliciesByName := make(map[uint]map[string]policyLite, len(teamIDToPolicies))
|
||||
for teamID, teamPolicySpecs := range teamIDToPolicies {
|
||||
teamIDToPoliciesByName[teamID] = make(map[string]policyLite, len(teamPolicySpecs))
|
||||
policyNames := make([]string, 0, len(teamPolicySpecs))
|
||||
for _, spec := range teamPolicySpecs {
|
||||
policyNames = append(policyNames, spec.Name)
|
||||
}
|
||||
|
||||
var query string
|
||||
var args []interface{}
|
||||
var err error
|
||||
if teamID == 0 {
|
||||
query, args, err = sqlx.In("SELECT name, query, platforms FROM policies WHERE team_id IS NULL AND name IN (?)", policyNames)
|
||||
} else {
|
||||
query, args, err = sqlx.In(
|
||||
"SELECT name, query, platforms FROM policies WHERE team_id = ? AND name IN (?)", &teamID, policyNames,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "building query to get policies by name")
|
||||
}
|
||||
policies := make([]policyLite, 0, len(teamPolicySpecs))
|
||||
err = sqlx.SelectContext(ctx, queryerContext, &policies, query, args...)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "getting policies by name")
|
||||
}
|
||||
for _, p := range policies {
|
||||
teamIDToPoliciesByName[teamID][p.Name] = p
|
||||
}
|
||||
}
|
||||
|
||||
return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
query := fmt.Sprintf(
|
||||
`
|
||||
|
|
@ -592,7 +689,7 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
|
|||
critical,
|
||||
calendar_events_enabled,
|
||||
checksum
|
||||
) VALUES ( ?, ?, ?, ?, ?, (SELECT IFNULL(MIN(id), NULL) FROM teams WHERE name = ?), ?, ?, ?, %s)
|
||||
) VALUES ( ?, ?, ?, ?, ?, ?, ?, ?, ?, %s)
|
||||
ON DUPLICATE KEY UPDATE
|
||||
query = VALUES(query),
|
||||
description = VALUES(description),
|
||||
|
|
@ -603,24 +700,45 @@ func (ds *Datastore) ApplyPolicySpecs(ctx context.Context, authorID uint, specs
|
|||
calendar_events_enabled = VALUES(calendar_events_enabled)
|
||||
`, policiesChecksumComputedColumn(),
|
||||
)
|
||||
for _, spec := range specs {
|
||||
|
||||
// We must normalize the name for full Unicode support (Unicode equivalence).
|
||||
spec.Name = norm.NFC.String(spec.Name)
|
||||
res, err := tx.ExecContext(ctx,
|
||||
query, spec.Name, spec.Query, spec.Description, authorID, spec.Resolution, spec.Team, spec.Platform, spec.Critical,
|
||||
spec.CalendarEventsEnabled,
|
||||
)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "exec ApplyPolicySpecs insert")
|
||||
for teamID, teamPolicySpecs := range teamIDToPolicies {
|
||||
var teamIDPtr *uint
|
||||
if teamID != 0 {
|
||||
teamIDPtr = &teamID
|
||||
}
|
||||
for _, spec := range teamPolicySpecs {
|
||||
|
||||
if insertOnDuplicateDidUpdate(res) {
|
||||
// when the upsert results in an UPDATE that *did* change some values,
|
||||
// it returns the updated ID as last inserted id.
|
||||
if lastID, _ := res.LastInsertId(); lastID > 0 {
|
||||
if err := cleanupPolicyMembershipOnPolicyUpdate(ctx, tx, uint(lastID), spec.Platform); err != nil {
|
||||
return err
|
||||
res, err := tx.ExecContext(
|
||||
ctx,
|
||||
query, spec.Name, spec.Query, spec.Description, authorID, spec.Resolution, teamIDPtr, spec.Platform, spec.Critical,
|
||||
spec.CalendarEventsEnabled,
|
||||
)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "exec ApplyPolicySpecs insert")
|
||||
}
|
||||
|
||||
if insertOnDuplicateDidUpdate(res) {
|
||||
// when the upsert results in an UPDATE that *did* change some values,
|
||||
// it returns the updated ID as last inserted id.
|
||||
if lastID, _ := res.LastInsertId(); lastID > 0 {
|
||||
var (
|
||||
shouldRemoveAllPolicyMemberships bool
|
||||
removePolicyStats bool
|
||||
)
|
||||
// Figure out if the query or platform changed
|
||||
if prev, ok := teamIDToPoliciesByName[teamID][spec.Name]; ok {
|
||||
switch {
|
||||
case prev.Query != spec.Query:
|
||||
shouldRemoveAllPolicyMemberships = true
|
||||
removePolicyStats = true
|
||||
case prev.Platforms != spec.Platform:
|
||||
removePolicyStats = true
|
||||
}
|
||||
}
|
||||
if err = cleanupPolicy(
|
||||
ctx, tx, uint(lastID), spec.Platform, shouldRemoveAllPolicyMemberships, removePolicyStats, ds.logger,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -739,7 +857,7 @@ func cleanupPolicyMembershipOnPolicyUpdate(ctx context.Context, db sqlx.ExecerCo
|
|||
|
||||
// cleanupPolicyMembership is similar to cleanupPolicyMembershipOnPolicyUpdate but without the platform constraints.
|
||||
// Used when we want to remove all policy membership.
|
||||
func (ds *Datastore) cleanupPolicyMembershipForPolicy(ctx context.Context, policyID uint) error {
|
||||
func cleanupPolicyMembershipForPolicy(ctx context.Context, exec sqlx.ExecerContext, policyID uint) error {
|
||||
// delete all policy memberships for the policy
|
||||
delStmt := `
|
||||
DELETE
|
||||
|
|
@ -754,21 +872,11 @@ func (ds *Datastore) cleanupPolicyMembershipForPolicy(ctx context.Context, polic
|
|||
pm.policy_id = ?
|
||||
`
|
||||
|
||||
_, err := ds.writer(ctx).ExecContext(ctx, delStmt, policyID)
|
||||
_, err := exec.ExecContext(ctx, delStmt, policyID)
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "cleanup policy membership")
|
||||
}
|
||||
|
||||
// delete all policy stats for the policy
|
||||
// wrapping in a retry to avoid deadlocks with the cleanups_then_aggregation cron job
|
||||
err = ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
_, err := tx.ExecContext(ctx, `DELETE FROM policy_stats WHERE policy_id = ?`, policyID)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return ctxerr.Wrap(ctx, err, "cleanup policy stats")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ func TestPolicies(t *testing.T) {
|
|||
{"PoliciesByID", testPoliciesByID},
|
||||
{"TeamPolicyTransfer", testTeamPolicyTransfer},
|
||||
{"ApplyPolicySpec", testApplyPolicySpec},
|
||||
{"ApplyPolicySpecWithQueryPlatformChanges", testApplyPolicySpecWithQueryPlatformChanges},
|
||||
{"Save", testPoliciesSave},
|
||||
{"DelUser", testPoliciesDelUser},
|
||||
{"FlippingPoliciesForHost", testFlippingPoliciesForHost},
|
||||
|
|
@ -1400,6 +1401,298 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) {
|
|||
}))
|
||||
}
|
||||
|
||||
func testApplyPolicySpecWithQueryPlatformChanges(t *testing.T, ds *Datastore) {
|
||||
ctx := context.Background()
|
||||
unicode, _ := strconv.Unquote(`"\uAC00"`) // 가
|
||||
unicodeEq, _ := strconv.Unquote(`"\u1100\u1161"`) // ᄀ + ᅡ
|
||||
|
||||
user1 := test.NewUser(t, ds, "User1", "user1@example.com", true)
|
||||
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1" + unicode})
|
||||
require.NoError(t, err)
|
||||
|
||||
globalNames := []string{"global query1" + unicode, "global query2" + unicode, "global query3" + unicode}
|
||||
teamNames := []string{"team query1", "team query2", "team query3"}
|
||||
require.NoError(
|
||||
t, ds.ApplyPolicySpecs(
|
||||
ctx, user1.ID, []*fleet.PolicySpec{
|
||||
{
|
||||
Name: globalNames[0],
|
||||
Query: "select 1;",
|
||||
Team: "",
|
||||
Platform: "",
|
||||
},
|
||||
{
|
||||
Name: globalNames[1],
|
||||
Query: "select 2;",
|
||||
Team: "",
|
||||
Platform: "darwin",
|
||||
},
|
||||
{
|
||||
Name: globalNames[2],
|
||||
Query: "select 3;",
|
||||
Team: "",
|
||||
Platform: "darwin,linux",
|
||||
},
|
||||
{
|
||||
Name: teamNames[0],
|
||||
Query: "select 1;",
|
||||
Team: "team1" + unicode,
|
||||
Platform: "",
|
||||
},
|
||||
{
|
||||
Name: teamNames[1],
|
||||
Query: "select 2;",
|
||||
Team: "team1" + unicode,
|
||||
Platform: "darwin",
|
||||
},
|
||||
{
|
||||
Name: teamNames[2],
|
||||
Query: "select 3;",
|
||||
Team: "team1" + unicodeEq,
|
||||
Platform: "darwin,linux",
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
// create hosts with different platforms, for that team
|
||||
const hostWin, hostMac, hostDeb, hostLin = 0, 1, 2, 3
|
||||
platforms := []string{"windows", "darwin", "debian", "linux"}
|
||||
teamHosts := make([]*fleet.Host, len(platforms))
|
||||
for i, pl := range platforms {
|
||||
id := fmt.Sprintf("%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i)
|
||||
h, err := ds.NewHost(
|
||||
ctx, &fleet.Host{
|
||||
OsqueryHostID: &id,
|
||||
DetailUpdatedAt: time.Now(),
|
||||
LabelUpdatedAt: time.Now(),
|
||||
PolicyUpdatedAt: time.Now(),
|
||||
SeenTime: time.Now(),
|
||||
NodeKey: &id,
|
||||
UUID: id,
|
||||
Hostname: id,
|
||||
Platform: pl,
|
||||
TeamID: ptr.Uint(team1.ID),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
teamHosts[i] = h
|
||||
}
|
||||
|
||||
// create hosts with different platforms, without team
|
||||
globalHosts := make([]*fleet.Host, len(platforms))
|
||||
for i, pl := range platforms {
|
||||
id := fmt.Sprintf("g%s-%d", strings.ReplaceAll(t.Name(), "/", "_"), i)
|
||||
h, err := ds.NewHost(
|
||||
ctx, &fleet.Host{
|
||||
OsqueryHostID: &id,
|
||||
DetailUpdatedAt: time.Now(),
|
||||
LabelUpdatedAt: time.Now(),
|
||||
PolicyUpdatedAt: time.Now(),
|
||||
SeenTime: time.Now(),
|
||||
NodeKey: &id,
|
||||
UUID: id,
|
||||
Hostname: id,
|
||||
Platform: pl,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
globalHosts[i] = h
|
||||
}
|
||||
|
||||
// load the global policies
|
||||
gPolicies, err := ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, gPolicies, 3)
|
||||
// load the team policies
|
||||
tPolicies, _, err := ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tPolicies, 3)
|
||||
|
||||
// index the policies by name for easier access in the rest of the test
|
||||
polsByName := make(map[string]*fleet.Policy, len(gPolicies)+len(tPolicies))
|
||||
globalPolsByName := make(map[string]*fleet.Policy, len(gPolicies))
|
||||
for _, pol := range tPolicies {
|
||||
polsByName[pol.Name] = pol
|
||||
}
|
||||
for _, pol := range gPolicies {
|
||||
globalPolsByName[pol.Name] = pol
|
||||
polsByName[pol.Name] = pol
|
||||
}
|
||||
|
||||
// record some results for each policy
|
||||
// Note: we are adding results to hosts that shouldn't have results, based on their platform.
|
||||
for _, h := range teamHosts {
|
||||
res := make(map[uint]*bool, len(polsByName))
|
||||
for _, pol := range polsByName {
|
||||
res[pol.ID] = ptr.Bool(false)
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for _, h := range globalHosts {
|
||||
res := make(map[uint]*bool, len(globalPolsByName))
|
||||
for _, pol := range globalPolsByName {
|
||||
res[pol.ID] = ptr.Bool(false)
|
||||
}
|
||||
err = ds.RecordPolicyQueryExecutions(ctx, h, res, time.Now(), false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
err = ds.UpdateHostPolicyCounts(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update host failure counts and ensure they are correct
|
||||
teamHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, teamHosts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 6, teamHosts[hostWin].FailingPoliciesCount)
|
||||
assert.Equal(t, 6, teamHosts[hostMac].FailingPoliciesCount)
|
||||
assert.Equal(t, 6, teamHosts[hostDeb].FailingPoliciesCount)
|
||||
assert.Equal(t, 6, teamHosts[hostLin].FailingPoliciesCount)
|
||||
globalHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, globalHosts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 3, globalHosts[hostWin].FailingPoliciesCount)
|
||||
assert.Equal(t, 3, globalHosts[hostMac].FailingPoliciesCount)
|
||||
assert.Equal(t, 3, globalHosts[hostDeb].FailingPoliciesCount)
|
||||
assert.Equal(t, 3, globalHosts[hostLin].FailingPoliciesCount)
|
||||
|
||||
// Ensure policy passing and failing counts are correct
|
||||
gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, gPolicies, 3)
|
||||
tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tPolicies, 3)
|
||||
|
||||
for _, pol := range gPolicies {
|
||||
polsByName[pol.Name] = pol
|
||||
}
|
||||
for _, pol := range tPolicies {
|
||||
polsByName[pol.Name] = pol
|
||||
}
|
||||
assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount)
|
||||
assert.Equal(t, uint(8), polsByName[globalNames[1]].FailingHostCount)
|
||||
assert.Equal(t, uint(8), polsByName[globalNames[2]].FailingHostCount)
|
||||
assert.Equal(t, uint(4), polsByName[teamNames[0]].FailingHostCount)
|
||||
assert.Equal(t, uint(4), polsByName[teamNames[1]].FailingHostCount)
|
||||
assert.Equal(t, uint(4), polsByName[teamNames[2]].FailingHostCount)
|
||||
|
||||
// Update policies
|
||||
require.NoError(
|
||||
t, ds.ApplyPolicySpecs(
|
||||
ctx, user1.ID, []*fleet.PolicySpec{
|
||||
{
|
||||
Name: globalNames[0],
|
||||
Query: "select 1;",
|
||||
Team: "",
|
||||
Platform: "",
|
||||
Description: "updated", // update description
|
||||
},
|
||||
{
|
||||
Name: globalNames[1],
|
||||
Query: "select 2 updated;", // update query
|
||||
Team: "",
|
||||
Platform: "darwin",
|
||||
},
|
||||
{
|
||||
Name: globalNames[2],
|
||||
Query: "select 3;",
|
||||
Team: "",
|
||||
Platform: "darwin", // update platform
|
||||
},
|
||||
{
|
||||
Name: "new global query",
|
||||
Query: "select 4;",
|
||||
Team: "",
|
||||
Platform: "",
|
||||
},
|
||||
{
|
||||
Name: teamNames[0],
|
||||
Query: "select 1;",
|
||||
Team: "team1" + unicode,
|
||||
Platform: "linux", // update platform
|
||||
},
|
||||
{
|
||||
Name: teamNames[1],
|
||||
Query: "select 2;",
|
||||
Team: "team1" + unicode,
|
||||
Platform: "darwin",
|
||||
CalendarEventsEnabled: true, // update calendar events
|
||||
},
|
||||
{
|
||||
Name: teamNames[2],
|
||||
Query: "select 3 updated;", // update query
|
||||
Team: "team1" + unicodeEq,
|
||||
Platform: "darwin,linux",
|
||||
},
|
||||
{
|
||||
Name: "new team query",
|
||||
Query: "select 4;",
|
||||
Team: "team1" + unicode,
|
||||
Platform: "",
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
// Update host failure counts and ensure they are correct
|
||||
teamHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, teamHosts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, teamHosts[hostWin].FailingPoliciesCount) // kept result from globalNames[0]
|
||||
assert.Equal(t, 3, teamHosts[hostMac].FailingPoliciesCount)
|
||||
assert.Equal(t, 2, teamHosts[hostDeb].FailingPoliciesCount)
|
||||
assert.Equal(t, 2, teamHosts[hostLin].FailingPoliciesCount)
|
||||
globalHosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, globalHosts)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, globalHosts[hostWin].FailingPoliciesCount)
|
||||
assert.Equal(t, 2, globalHosts[hostMac].FailingPoliciesCount)
|
||||
assert.Equal(t, 1, globalHosts[hostDeb].FailingPoliciesCount)
|
||||
assert.Equal(t, 1, globalHosts[hostLin].FailingPoliciesCount)
|
||||
|
||||
// Ensure policy passing and failing counts are correct
|
||||
gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, gPolicies, 4)
|
||||
tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tPolicies, 4)
|
||||
|
||||
for _, pol := range gPolicies {
|
||||
polsByName[pol.Name] = pol
|
||||
}
|
||||
for _, pol := range tPolicies {
|
||||
polsByName[pol.Name] = pol
|
||||
}
|
||||
assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount)
|
||||
assert.Equal(t, uint(0), polsByName[globalNames[1]].FailingHostCount) // updated query
|
||||
assert.Equal(t, uint(0), polsByName[globalNames[2]].FailingHostCount) // updated platform
|
||||
assert.Equal(t, uint(0), polsByName[teamNames[0]].FailingHostCount) // updated platform
|
||||
assert.Equal(t, uint(4), polsByName[teamNames[1]].FailingHostCount)
|
||||
assert.Equal(t, uint(0), polsByName[teamNames[2]].FailingHostCount) // updated query
|
||||
|
||||
err = ds.UpdateHostPolicyCounts(ctx)
|
||||
require.NoError(t, err)
|
||||
gPolicies, err = ds.ListGlobalPolicies(ctx, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, gPolicies, 4)
|
||||
tPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tPolicies, 4)
|
||||
|
||||
for _, pol := range gPolicies {
|
||||
polsByName[pol.Name] = pol
|
||||
}
|
||||
for _, pol := range tPolicies {
|
||||
polsByName[pol.Name] = pol
|
||||
}
|
||||
assert.Equal(t, uint(8), polsByName[globalNames[0]].FailingHostCount) // platform is "" -- no change
|
||||
assert.Equal(t, uint(0), polsByName[globalNames[1]].FailingHostCount) // updated query
|
||||
assert.Equal(t, uint(2), polsByName[globalNames[2]].FailingHostCount) // updated platform
|
||||
assert.Equal(t, uint(2), polsByName[teamNames[0]].FailingHostCount) // updated platform
|
||||
assert.Equal(t, uint(1), polsByName[teamNames[1]].FailingHostCount) // platform is "darwin" -- no change
|
||||
assert.Equal(t, uint(0), polsByName[teamNames[2]].FailingHostCount) // updated query
|
||||
|
||||
}
|
||||
|
||||
func testPoliciesSave(t *testing.T, ds *Datastore) {
|
||||
user1 := test.NewUser(t, ds, "User1", "user1@example.com", true)
|
||||
ctx := context.Background()
|
||||
|
|
@ -1412,7 +1705,8 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
|
|||
Name: "non-existent query",
|
||||
Query: "select 1;",
|
||||
},
|
||||
}, false)
|
||||
}, false, false,
|
||||
)
|
||||
require.Error(t, err)
|
||||
var nfe *notFoundError
|
||||
require.True(t, errors.As(err, &nfe))
|
||||
|
|
@ -1473,7 +1767,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
|
|||
gp2 := *gp
|
||||
gp2.Name = "global query updated"
|
||||
gp2.Critical = true
|
||||
err = ds.SavePolicy(ctx, &gp2, false)
|
||||
err = ds.SavePolicy(ctx, &gp2, false, false)
|
||||
require.NoError(t, err)
|
||||
gp, err = ds.Policy(ctx, gp.ID)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -1493,7 +1787,7 @@ func testPoliciesSave(t *testing.T, ds *Datastore) {
|
|||
tp2.Resolution = ptr.String("team1 query resolution updated")
|
||||
tp2.Critical = false
|
||||
tp2.CalendarEventsEnabled = false
|
||||
err = ds.SavePolicy(ctx, &tp2, true)
|
||||
err = ds.SavePolicy(ctx, &tp2, true, true)
|
||||
require.NoError(t, err)
|
||||
tp1, err = ds.Policy(ctx, tp1.ID)
|
||||
tp2.UpdateCreateTimestamps = tp1.UpdateCreateTimestamps
|
||||
|
|
@ -1586,7 +1880,7 @@ func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) {
|
|||
assert.Equal(t, uint(1), inheritedPolicies[0].PassingHostCount)
|
||||
|
||||
// Update the global policy sql to trigger a cache invalidation
|
||||
err = ds.SavePolicy(ctx, globalPolicy, true)
|
||||
err = ds.SavePolicy(ctx, globalPolicy, true, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
globalPolicy, err = ds.Policy(ctx, globalPolicy.ID)
|
||||
|
|
@ -1599,8 +1893,8 @@ func testCachedPolicyCountDeletesOnPolicyChange(t *testing.T, ds *Datastore) {
|
|||
assert.Equal(t, uint(1), teamPolicies[0].PassingHostCount)
|
||||
assert.Equal(t, uint(0), inheritedPolicies[0].PassingHostCount)
|
||||
|
||||
// Update the team policy sql to trigger a cache invalidation
|
||||
err = ds.SavePolicy(ctx, teamPolicy, true)
|
||||
// Update the team policy platform to trigger a cache invalidation
|
||||
err = ds.SavePolicy(ctx, teamPolicy, false, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
teamPolicies, inheritedPolicies, err = ds.ListTeamPolicies(ctx, team1.ID, fleet.ListOptions{}, fleet.ListOptions{})
|
||||
|
|
@ -1921,9 +2215,9 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
|
|||
}
|
||||
|
||||
// updating without change works fine
|
||||
err = ds.SavePolicy(ctx, polsByName["g1"], false)
|
||||
err = ds.SavePolicy(ctx, polsByName["g1"], false, false)
|
||||
require.NoError(t, err)
|
||||
err = ds.SavePolicy(ctx, polsByName["t2"], false)
|
||||
err = ds.SavePolicy(ctx, polsByName["t2"], false, false)
|
||||
require.NoError(t, err)
|
||||
// apply specs that result in an update (without change) works fine
|
||||
err = ds.ApplyPolicySpecs(ctx, user.ID, []*fleet.PolicySpec{
|
||||
|
|
@ -1975,7 +2269,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
|
|||
g1 := polsByName["g1"]
|
||||
g1.Platform = "linux"
|
||||
polsByName["g1"] = g1
|
||||
err = ds.SavePolicy(ctx, g1, false)
|
||||
err = ds.SavePolicy(ctx, g1, false, false)
|
||||
require.NoError(t, err)
|
||||
wantHostsByPol["g1"] = []uint{globalHosts[hostDeb].ID, globalHosts[hostLin].ID}
|
||||
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
|
||||
|
|
@ -1984,7 +2278,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
|
|||
t1 := polsByName["t1"]
|
||||
t1.Platform = "windows,darwin"
|
||||
polsByName["t1"] = t1
|
||||
err = ds.SavePolicy(ctx, t1, false)
|
||||
err = ds.SavePolicy(ctx, t1, false, false)
|
||||
require.NoError(t, err)
|
||||
wantHostsByPol["t1"] = []uint{teamHosts[hostWin].ID, teamHosts[hostMac].ID}
|
||||
assertPolicyMembership(t, ds, polsByName, wantHostsByPol)
|
||||
|
|
@ -2723,7 +3017,7 @@ func testPoliciesNameUnicode(t *testing.T, ds *Datastore) {
|
|||
policyEmoji, err := ds.NewGlobalPolicy(context.Background(), nil, fleet.PolicyPayload{Name: "💻"})
|
||||
require.NoError(t, err)
|
||||
err = ds.SavePolicy(
|
||||
context.Background(), &fleet.Policy{PolicyData: fleet.PolicyData{ID: policyEmoji.ID, Name: equivalentNames[1]}}, false,
|
||||
context.Background(), &fleet.Policy{PolicyData: fleet.PolicyData{ID: policyEmoji.ID, Name: equivalentNames[1]}}, false, false,
|
||||
)
|
||||
assert.True(t, isDuplicate(err), err)
|
||||
|
||||
|
|
@ -3270,10 +3564,10 @@ func testGetTeamHostsPolicyMemberships(t *testing.T, ds *Datastore) {
|
|||
//
|
||||
|
||||
team2Policy1.Platform = "darwin"
|
||||
err = ds.SavePolicy(ctx, team1Policy1, false)
|
||||
err = ds.SavePolicy(ctx, team1Policy1, false, true)
|
||||
require.NoError(t, err)
|
||||
team1Policy1.Platform = "darwin"
|
||||
err = ds.SavePolicy(ctx, team2Policy1, false)
|
||||
err = ds.SavePolicy(ctx, team2Policy1, false, true)
|
||||
require.NoError(t, err)
|
||||
|
||||
//
|
||||
|
|
|
|||
|
|
@ -602,7 +602,7 @@ type Datastore interface {
|
|||
// SavePolicy updates some fields of the given policy on the datastore.
|
||||
//
|
||||
// It is also used to update team policies.
|
||||
SavePolicy(ctx context.Context, p *Policy, shouldRemoveAllPolicyMemberships bool) error
|
||||
SavePolicy(ctx context.Context, p *Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error
|
||||
|
||||
ListGlobalPolicies(ctx context.Context, opts ListOptions) ([]*Policy, error)
|
||||
PoliciesByID(ctx context.Context, ids []uint) (map[uint]*Policy, error)
|
||||
|
|
|
|||
|
|
@ -435,7 +435,7 @@ type NewGlobalPolicyFunc func(ctx context.Context, authorID *uint, args fleet.Po
|
|||
|
||||
type PolicyFunc func(ctx context.Context, id uint) (*fleet.Policy, error)
|
||||
|
||||
type SavePolicyFunc func(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error
|
||||
type SavePolicyFunc func(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error
|
||||
|
||||
type ListGlobalPoliciesFunc func(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error)
|
||||
|
||||
|
|
@ -3729,11 +3729,11 @@ func (s *DataStore) Policy(ctx context.Context, id uint) (*fleet.Policy, error)
|
|||
return s.PolicyFunc(ctx, id)
|
||||
}
|
||||
|
||||
func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool) error {
|
||||
func (s *DataStore) SavePolicy(ctx context.Context, p *fleet.Policy, shouldRemoveAllPolicyMemberships bool, removePolicyStats bool) error {
|
||||
s.mu.Lock()
|
||||
s.SavePolicyFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.SavePolicyFunc(ctx, p, shouldRemoveAllPolicyMemberships)
|
||||
return s.SavePolicyFunc(ctx, p, shouldRemoveAllPolicyMemberships, removePolicyStats)
|
||||
}
|
||||
|
||||
func (s *DataStore) ListGlobalPolicies(ctx context.Context, opts fleet.ListOptions) ([]*fleet.Policy, error) {
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ func TestGlobalPoliciesAuth(t *testing.T) {
|
|||
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error {
|
||||
return nil
|
||||
}
|
||||
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error {
|
||||
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool, removePolicyStats bool) error {
|
||||
return nil
|
||||
}
|
||||
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
|
||||
|
|
|
|||
|
|
@ -2428,6 +2428,65 @@ func (s *integrationTestSuite) TestGlobalPoliciesProprietary() {
|
|||
assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount)
|
||||
assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount)
|
||||
|
||||
// Record query executions
|
||||
require.NoError(
|
||||
t, s.ds.RecordPolicyQueryExecutions(
|
||||
context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now(), false,
|
||||
),
|
||||
)
|
||||
require.NoError(
|
||||
t, s.ds.RecordPolicyQueryExecutions(
|
||||
context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now(), false,
|
||||
),
|
||||
)
|
||||
// Update policy stats
|
||||
require.NoError(t, s.ds.UpdateHostPolicyCounts(context.Background()))
|
||||
|
||||
// Fetch policy to make sure stats are updated
|
||||
s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse)
|
||||
require.Len(t, policiesResponse.Policies, 1)
|
||||
assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount)
|
||||
assert.Equal(t, uint(1), policiesResponse.Policies[0].PassingHostCount)
|
||||
|
||||
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
|
||||
listHostsResp = listHostsResponse{}
|
||||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 1)
|
||||
|
||||
// Modify the platform for the policy, which should clear the policy stats
|
||||
mgpParams = modifyGlobalPolicyRequest{
|
||||
ModifyPolicyPayload: fleet.ModifyPolicyPayload{
|
||||
Platform: ptr.String("linux"),
|
||||
},
|
||||
}
|
||||
mgpResp = modifyGlobalPolicyResponse{}
|
||||
s.DoJSON("PATCH", fmt.Sprintf("/api/latest/fleet/policies/%d", gpResp.Policy.ID), mgpParams, http.StatusOK, &mgpResp)
|
||||
require.NotNil(t, gpResp.Policy)
|
||||
assert.Equal(t, "TestQuery4", mgpResp.Policy.Name)
|
||||
assert.Equal(t, "select * from users;", mgpResp.Policy.Query)
|
||||
assert.Equal(t, "Some description updated", mgpResp.Policy.Description)
|
||||
require.NotNil(t, mgpResp.Policy.Resolution)
|
||||
assert.Equal(t, "some global resolution updated", *mgpResp.Policy.Resolution)
|
||||
assert.Equal(t, "linux", mgpResp.Policy.Platform)
|
||||
assert.Equal(t, uint(0), mgpResp.Policy.FailingHostCount)
|
||||
assert.Equal(t, uint(0), mgpResp.Policy.PassingHostCount)
|
||||
|
||||
// Fetch policy to make sure stats are updated
|
||||
s.DoJSON("GET", "/api/latest/fleet/policies", nil, http.StatusOK, &policiesResponse)
|
||||
require.Len(t, policiesResponse.Policies, 1)
|
||||
assert.Equal(t, uint(0), policiesResponse.Policies[0].FailingHostCount)
|
||||
assert.Equal(t, uint(0), policiesResponse.Policies[0].PassingHostCount)
|
||||
|
||||
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
|
||||
listHostsResp = listHostsResponse{}
|
||||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 0)
|
||||
|
||||
listHostsURL = fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d&policy_response=failing", policiesResponse.Policies[0].ID)
|
||||
listHostsResp = listHostsResponse{}
|
||||
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
|
||||
require.Len(t, listHostsResp.Hosts, 0)
|
||||
|
||||
deletePolicyParams := deleteGlobalPoliciesRequest{IDs: []uint{policiesResponse.Policies[0].ID}}
|
||||
deletePolicyResp := deleteGlobalPoliciesResponse{}
|
||||
s.DoJSON("POST", "/api/latest/fleet/policies/delete", deletePolicyParams, http.StatusOK, &deletePolicyResp)
|
||||
|
|
|
|||
|
|
@ -368,7 +368,8 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
|
|||
})
|
||||
}
|
||||
|
||||
var shouldRemoveAll bool
|
||||
var removeAllMemberships bool
|
||||
var removeStats bool
|
||||
if p.Name != nil {
|
||||
policy.Name = *p.Name
|
||||
}
|
||||
|
|
@ -377,9 +378,8 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
|
|||
}
|
||||
if p.Query != nil {
|
||||
if policy.Query != *p.Query {
|
||||
shouldRemoveAll = true
|
||||
policy.FailingHostCount = 0
|
||||
policy.PassingHostCount = 0
|
||||
removeAllMemberships = true
|
||||
removeStats = true
|
||||
}
|
||||
policy.Query = *p.Query
|
||||
}
|
||||
|
|
@ -387,6 +387,9 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
|
|||
policy.Resolution = p.Resolution
|
||||
}
|
||||
if p.Platform != nil {
|
||||
if policy.Platform != *p.Platform {
|
||||
removeStats = true
|
||||
}
|
||||
policy.Platform = *p.Platform
|
||||
}
|
||||
if p.Critical != nil {
|
||||
|
|
@ -395,9 +398,13 @@ func (svc *Service) modifyPolicy(ctx context.Context, teamID *uint, id uint, p f
|
|||
if p.CalendarEventsEnabled != nil {
|
||||
policy.CalendarEventsEnabled = *p.CalendarEventsEnabled
|
||||
}
|
||||
if removeStats {
|
||||
policy.FailingHostCount = 0
|
||||
policy.PassingHostCount = 0
|
||||
}
|
||||
logging.WithExtras(ctx, "name", policy.Name, "sql", policy.Query)
|
||||
|
||||
err = svc.ds.SavePolicy(ctx, policy, shouldRemoveAll)
|
||||
err = svc.ds.SavePolicy(ctx, policy, removeAllMemberships, removeStats)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "saving policy")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ func TestTeamPoliciesAuth(t *testing.T) {
|
|||
}
|
||||
return nil, nil
|
||||
}
|
||||
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool) error {
|
||||
ds.SavePolicyFunc = func(ctx context.Context, p *fleet.Policy, shouldDeleteAll bool, removePolicyStats bool) error {
|
||||
return nil
|
||||
}
|
||||
ds.DeleteTeamPoliciesFunc = func(ctx context.Context, teamID uint, ids []uint) ([]uint, error) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue