diff --git a/changes/clear-host-policies-on-team-transfer b/changes/clear-host-policies-on-team-transfer new file mode 100644 index 0000000000..f5387fb03f --- /dev/null +++ b/changes/clear-host-policies-on-team-transfer @@ -0,0 +1 @@ +* When transferring a host from team to team, clear the policy results for that host. diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 77cfff78f7..855c8029bb 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -810,20 +810,29 @@ func (d *Datastore) AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs [] return nil } - sql := ` - UPDATE hosts SET team_id = ? - WHERE id IN (?) - ` - sql, args, err := sqlx.In(sql, teamID, hostIDs) - if err != nil { - return errors.Wrap(err, "sqlx.In AddHostsToTeam") - } + return d.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { + // hosts can only be in one team, so if there's a policy that has a team id and a result from one of our hosts + // it can only be from the previous team they are being transferred from + query, args, err := sqlx.In(`DELETE FROM policy_membership_history + WHERE policy_id IN (SELECT id FROM policies WHERE team_id IS NOT NULL) AND host_id IN (?)`, hostIDs) + if err != nil { + return errors.Wrap(err, "add host to team sqlx in") + } + if _, err := tx.ExecContext(ctx, query, args...); err != nil { + return errors.Wrap(err, "exec AddHostsToTeam delete policy membership history") + } - if _, err := d.writer.ExecContext(ctx, sql, args...); err != nil { - return errors.Wrap(err, "exec AddHostsToTeam") - } + query, args, err = sqlx.In(`UPDATE hosts SET team_id = ? WHERE id IN (?)`, teamID, hostIDs) + if err != nil { + return errors.Wrap(err, "sqlx.In AddHostsToTeam") + } - return nil + if _, err := tx.ExecContext(ctx, query, args...); err != nil { + return errors.Wrap(err, "exec AddHostsToTeam") + } + + return nil + }) } func saveHostAdditionalDB(ctx context.Context, exec sqlx.ExecerContext, host *fleet.Host) error { diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 8897d98995..007350b972 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -21,6 +21,9 @@ func TestPolicies(t *testing.T) { }{ {"NewGlobalPolicy", testPoliciesNewGlobalPolicy}, {"MembershipView", testPoliciesMembershipView}, + {"TeamPolicy", testTeamPolicy}, + {"PolicyQueriesForHost", testPolicyQueriesForHost}, + {"TeamPolicyTransfer", testTeamPolicyTransfer}, } for _, c := range cases { t.Run(c.name, func(t *testing.T) { @@ -163,10 +166,7 @@ func testPoliciesMembershipView(t *testing.T, ds *Datastore) { assert.Equal(t, q2.Query, queries[fmt.Sprint(q2.ID)]) } -func TestTeamPolicy(t *testing.T) { - ds := CreateMySQLDS(t) - defer ds.Close() - +func testTeamPolicy(t *testing.T, ds *Datastore) { team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"}) require.NoError(t, err) @@ -225,10 +225,7 @@ func TestTeamPolicy(t *testing.T) { require.Len(t, teamPolicies, 0) } -func TestPolicyQueriesForHost(t *testing.T) { - ds := CreateMySQLDS(t) - defer ds.Close() - +func testPolicyQueriesForHost(t *testing.T, ds *Datastore) { team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"}) require.NoError(t, err) @@ -291,3 +288,68 @@ func TestPolicyQueriesForHost(t *testing.T) { require.Len(t, queries, 1) assert.Equal(t, q.Query, queries[fmt.Sprint(q.ID)]) } + +func testTeamPolicyTransfer(t *testing.T, ds *Datastore) { + team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: t.Name() + "team1"}) + require.NoError(t, err) + + team2, err := ds.NewTeam(context.Background(), &fleet.Team{Name: t.Name() + "team2"}) + require.NoError(t, err) + + host1, err := ds.NewHost(context.Background(), &fleet.Host{ + OsqueryHostID: "1234", + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: "1", + UUID: "1", + Hostname: "foo.local", + }) + require.NoError(t, err) + + require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID})) + host1, err = ds.Host(context.Background(), host1.ID) + require.NoError(t, err) + + q, err := ds.NewQuery(context.Background(), &fleet.Query{ + Name: "query1", + Description: "query1 desc", + Query: "select 1;", + Saved: true, + }) + require.NoError(t, err) + teamPolicy, err := ds.NewTeamPolicy(context.Background(), team1.ID, q.ID) + require.NoError(t, err) + + globalPolicy, err := ds.NewGlobalPolicy(context.Background(), q.ID) + require.NoError(t, err) + + require.NoError(t, ds.RecordPolicyQueryExecutions( + context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now())) + require.NoError(t, ds.RecordPolicyQueryExecutions( + context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now())) + + checkPassingCount := func(expectedCount uint) { + policies, err := ds.ListTeamPolicies(context.Background(), team1.ID) + require.NoError(t, err) + require.Len(t, policies, 1) + + assert.Equal(t, expectedCount, policies[0].PassingHostCount) + + policies, err = ds.ListGlobalPolicies(context.Background()) + require.NoError(t, err) + require.Len(t, policies, 1) + assert.Equal(t, uint(1), policies[0].PassingHostCount) + + policies, err = ds.ListTeamPolicies(context.Background(), team2.ID) + require.NoError(t, err) + require.Len(t, policies, 0) + } + + checkPassingCount(1) + + require.NoError(t, ds.AddHostsToTeam(context.Background(), ptr.Uint(team2.ID), []uint{host1.ID})) + + checkPassingCount(0) +}