mirror of
https://github.com/fleetdm/fleet
synced 2026-05-23 08:58:41 +00:00
Delete policies for hosts in teams before transferring them (#2383)
* Delete policies for hosts in teams before transferring them * Add missing error check
This commit is contained in:
parent
ddc6b300d4
commit
70cf7aa0a0
3 changed files with 92 additions and 20 deletions
1
changes/clear-host-policies-on-team-transfer
Normal file
1
changes/clear-host-policies-on-team-transfer
Normal file
|
|
@ -0,0 +1 @@
|
|||
* When transferring a host from team to team, clear the policy results for that host.
|
||||
|
|
@ -810,20 +810,29 @@ func (d *Datastore) AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs []
|
|||
return nil
|
||||
}
|
||||
|
||||
sql := `
|
||||
UPDATE hosts SET team_id = ?
|
||||
WHERE id IN (?)
|
||||
`
|
||||
sql, args, err := sqlx.In(sql, teamID, hostIDs)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "sqlx.In AddHostsToTeam")
|
||||
}
|
||||
return d.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
||||
// hosts can only be in one team, so if there's a policy that has a team id and a result from one of our hosts
|
||||
// it can only be from the previous team they are being transferred from
|
||||
query, args, err := sqlx.In(`DELETE FROM policy_membership_history
|
||||
WHERE policy_id IN (SELECT id FROM policies WHERE team_id IS NOT NULL) AND host_id IN (?)`, hostIDs)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "add host to team sqlx in")
|
||||
}
|
||||
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
|
||||
return errors.Wrap(err, "exec AddHostsToTeam delete policy membership history")
|
||||
}
|
||||
|
||||
if _, err := d.writer.ExecContext(ctx, sql, args...); err != nil {
|
||||
return errors.Wrap(err, "exec AddHostsToTeam")
|
||||
}
|
||||
query, args, err = sqlx.In(`UPDATE hosts SET team_id = ? WHERE id IN (?)`, teamID, hostIDs)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "sqlx.In AddHostsToTeam")
|
||||
}
|
||||
|
||||
return nil
|
||||
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
|
||||
return errors.Wrap(err, "exec AddHostsToTeam")
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func saveHostAdditionalDB(ctx context.Context, exec sqlx.ExecerContext, host *fleet.Host) error {
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ func TestPolicies(t *testing.T) {
|
|||
}{
|
||||
{"NewGlobalPolicy", testPoliciesNewGlobalPolicy},
|
||||
{"MembershipView", testPoliciesMembershipView},
|
||||
{"TeamPolicy", testTeamPolicy},
|
||||
{"PolicyQueriesForHost", testPolicyQueriesForHost},
|
||||
{"TeamPolicyTransfer", testTeamPolicyTransfer},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
|
|
@ -163,10 +166,7 @@ func testPoliciesMembershipView(t *testing.T, ds *Datastore) {
|
|||
assert.Equal(t, q2.Query, queries[fmt.Sprint(q2.ID)])
|
||||
}
|
||||
|
||||
func TestTeamPolicy(t *testing.T) {
|
||||
ds := CreateMySQLDS(t)
|
||||
defer ds.Close()
|
||||
|
||||
func testTeamPolicy(t *testing.T, ds *Datastore) {
|
||||
team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -225,10 +225,7 @@ func TestTeamPolicy(t *testing.T) {
|
|||
require.Len(t, teamPolicies, 0)
|
||||
}
|
||||
|
||||
func TestPolicyQueriesForHost(t *testing.T) {
|
||||
ds := CreateMySQLDS(t)
|
||||
defer ds.Close()
|
||||
|
||||
func testPolicyQueriesForHost(t *testing.T, ds *Datastore) {
|
||||
team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -291,3 +288,68 @@ func TestPolicyQueriesForHost(t *testing.T) {
|
|||
require.Len(t, queries, 1)
|
||||
assert.Equal(t, q.Query, queries[fmt.Sprint(q.ID)])
|
||||
}
|
||||
|
||||
func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
|
||||
team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: t.Name() + "team1"})
|
||||
require.NoError(t, err)
|
||||
|
||||
team2, err := ds.NewTeam(context.Background(), &fleet.Team{Name: t.Name() + "team2"})
|
||||
require.NoError(t, err)
|
||||
|
||||
host1, err := ds.NewHost(context.Background(), &fleet.Host{
|
||||
OsqueryHostID: "1234",
|
||||
DetailUpdatedAt: time.Now(),
|
||||
LabelUpdatedAt: time.Now(),
|
||||
PolicyUpdatedAt: time.Now(),
|
||||
SeenTime: time.Now(),
|
||||
NodeKey: "1",
|
||||
UUID: "1",
|
||||
Hostname: "foo.local",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID}))
|
||||
host1, err = ds.Host(context.Background(), host1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
q, err := ds.NewQuery(context.Background(), &fleet.Query{
|
||||
Name: "query1",
|
||||
Description: "query1 desc",
|
||||
Query: "select 1;",
|
||||
Saved: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
teamPolicy, err := ds.NewTeamPolicy(context.Background(), team1.ID, q.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
globalPolicy, err := ds.NewGlobalPolicy(context.Background(), q.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(
|
||||
context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now()))
|
||||
require.NoError(t, ds.RecordPolicyQueryExecutions(
|
||||
context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now()))
|
||||
|
||||
checkPassingCount := func(expectedCount uint) {
|
||||
policies, err := ds.ListTeamPolicies(context.Background(), team1.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 1)
|
||||
|
||||
assert.Equal(t, expectedCount, policies[0].PassingHostCount)
|
||||
|
||||
policies, err = ds.ListGlobalPolicies(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 1)
|
||||
assert.Equal(t, uint(1), policies[0].PassingHostCount)
|
||||
|
||||
policies, err = ds.ListTeamPolicies(context.Background(), team2.ID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, policies, 0)
|
||||
}
|
||||
|
||||
checkPassingCount(1)
|
||||
|
||||
require.NoError(t, ds.AddHostsToTeam(context.Background(), ptr.Uint(team2.ID), []uint{host1.ID}))
|
||||
|
||||
checkPassingCount(0)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue