Delete policies for hosts in teams before transferring them (#2383)

* Delete policies for hosts in teams before transferring them

* Add missing error check
This commit is contained in:
Tomas Touceda 2021-10-05 15:48:26 -03:00 committed by GitHub
parent ddc6b300d4
commit 70cf7aa0a0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 92 additions and 20 deletions

View file

@ -0,0 +1 @@
* When transferring a host from team to team, clear the policy results for that host.

View file

@ -810,20 +810,29 @@ func (d *Datastore) AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs []
return nil
}
sql := `
UPDATE hosts SET team_id = ?
WHERE id IN (?)
`
sql, args, err := sqlx.In(sql, teamID, hostIDs)
if err != nil {
return errors.Wrap(err, "sqlx.In AddHostsToTeam")
}
return d.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
// hosts can only be in one team, so if there's a policy that has a team id and a result from one of our hosts
// it can only be from the previous team they are being transferred from
query, args, err := sqlx.In(`DELETE FROM policy_membership_history
WHERE policy_id IN (SELECT id FROM policies WHERE team_id IS NOT NULL) AND host_id IN (?)`, hostIDs)
if err != nil {
return errors.Wrap(err, "add host to team sqlx in")
}
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return errors.Wrap(err, "exec AddHostsToTeam delete policy membership history")
}
if _, err := d.writer.ExecContext(ctx, sql, args...); err != nil {
return errors.Wrap(err, "exec AddHostsToTeam")
}
query, args, err = sqlx.In(`UPDATE hosts SET team_id = ? WHERE id IN (?)`, teamID, hostIDs)
if err != nil {
return errors.Wrap(err, "sqlx.In AddHostsToTeam")
}
return nil
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return errors.Wrap(err, "exec AddHostsToTeam")
}
return nil
})
}
func saveHostAdditionalDB(ctx context.Context, exec sqlx.ExecerContext, host *fleet.Host) error {

View file

@ -21,6 +21,9 @@ func TestPolicies(t *testing.T) {
}{
{"NewGlobalPolicy", testPoliciesNewGlobalPolicy},
{"MembershipView", testPoliciesMembershipView},
{"TeamPolicy", testTeamPolicy},
{"PolicyQueriesForHost", testPolicyQueriesForHost},
{"TeamPolicyTransfer", testTeamPolicyTransfer},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
@ -163,10 +166,7 @@ func testPoliciesMembershipView(t *testing.T, ds *Datastore) {
assert.Equal(t, q2.Query, queries[fmt.Sprint(q2.ID)])
}
func TestTeamPolicy(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
func testTeamPolicy(t *testing.T, ds *Datastore) {
team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"})
require.NoError(t, err)
@ -225,10 +225,7 @@ func TestTeamPolicy(t *testing.T) {
require.Len(t, teamPolicies, 0)
}
func TestPolicyQueriesForHost(t *testing.T) {
ds := CreateMySQLDS(t)
defer ds.Close()
func testPolicyQueriesForHost(t *testing.T, ds *Datastore) {
team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"})
require.NoError(t, err)
@ -291,3 +288,68 @@ func TestPolicyQueriesForHost(t *testing.T) {
require.Len(t, queries, 1)
assert.Equal(t, q.Query, queries[fmt.Sprint(q.ID)])
}
func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: t.Name() + "team1"})
require.NoError(t, err)
team2, err := ds.NewTeam(context.Background(), &fleet.Team{Name: t.Name() + "team2"})
require.NoError(t, err)
host1, err := ds.NewHost(context.Background(), &fleet.Host{
OsqueryHostID: "1234",
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: "1",
UUID: "1",
Hostname: "foo.local",
})
require.NoError(t, err)
require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID}))
host1, err = ds.Host(context.Background(), host1.ID)
require.NoError(t, err)
q, err := ds.NewQuery(context.Background(), &fleet.Query{
Name: "query1",
Description: "query1 desc",
Query: "select 1;",
Saved: true,
})
require.NoError(t, err)
teamPolicy, err := ds.NewTeamPolicy(context.Background(), team1.ID, q.ID)
require.NoError(t, err)
globalPolicy, err := ds.NewGlobalPolicy(context.Background(), q.ID)
require.NoError(t, err)
require.NoError(t, ds.RecordPolicyQueryExecutions(
context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now()))
require.NoError(t, ds.RecordPolicyQueryExecutions(
context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now()))
checkPassingCount := func(expectedCount uint) {
policies, err := ds.ListTeamPolicies(context.Background(), team1.ID)
require.NoError(t, err)
require.Len(t, policies, 1)
assert.Equal(t, expectedCount, policies[0].PassingHostCount)
policies, err = ds.ListGlobalPolicies(context.Background())
require.NoError(t, err)
require.Len(t, policies, 1)
assert.Equal(t, uint(1), policies[0].PassingHostCount)
policies, err = ds.ListTeamPolicies(context.Background(), team2.ID)
require.NoError(t, err)
require.Len(t, policies, 0)
}
checkPassingCount(1)
require.NoError(t, ds.AddHostsToTeam(context.Background(), ptr.Uint(team2.ID), []uint{host1.ID}))
checkPassingCount(0)
}