From d321cfc68e62f3d80f61ba00afc3306b17256fc7 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Wed, 12 Oct 2022 08:35:36 -0400 Subject: [PATCH] Add inherited policies to the team's list policies response payload (#8068) --- ...ssue-7879-add-inherited-policies-for-teams | 1 + docs/Using-Fleet/REST-API.md | 18 ++ server/datastore/mysql/policies.go | 86 ++++-- server/datastore/mysql/policies_test.go | 262 +++++++++++++----- server/fleet/datastore.go | 2 +- server/fleet/service.go | 2 +- server/mock/datastore_mock.go | 4 +- server/service/integration_core_test.go | 1 + server/service/integration_enterprise_test.go | 12 + server/service/team_policies.go | 15 +- server/service/team_policies_test.go | 6 +- 11 files changed, 298 insertions(+), 111 deletions(-) create mode 100644 changes/issue-7879-add-inherited-policies-for-teams diff --git a/changes/issue-7879-add-inherited-policies-for-teams b/changes/issue-7879-add-inherited-policies-for-teams new file mode 100644 index 0000000000..93bfe4227c --- /dev/null +++ b/changes/issue-7879-add-inherited-policies-for-teams @@ -0,0 +1 @@ +* Added the `inherited_policies` array to the `GET /teams/{team_id}/policies` endpoint that lists the global policies inherited by the team, along with the pass/fail counts only for hosts that belong to that team. diff --git a/docs/Using-Fleet/REST-API.md b/docs/Using-Fleet/REST-API.md index 82b6700d86..44c23ad9df 100644 --- a/docs/Using-Fleet/REST-API.md +++ b/docs/Using-Fleet/REST-API.md @@ -3506,6 +3506,24 @@ Team policies work the same as policies, but at the team level. "passing_host_count": 2300, "failing_host_count": 0 } + ], + "inherited_policies": [ + { + "id": 136, + "name": "Arbitrary Test Policy (all platforms) (all teams)", + "query": "SELECT 1 FROM osquery_info WHERE 1=1;", + "description": "If you're seeing this, mostly likely this is because someone is testing out failing policies in dogfood. You can ignore this.", + "author_id": 77, + "author_name": "Test Admin", + "author_email": "test@admin.com", + "team_id": null, + "resolution": "To make it pass, change \"1=0\" to \"1=1\". To make it fail, change \"1=1\" to \"1=0\".", + "platform": "darwin,windows,linux", + "created_at": "2022-08-04T19:30:18Z", + "updated_at": "2022-08-30T15:08:26Z", + "passing_host_count": 10, + "failing_host_count": 9 + } ] } ``` diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index 2cf25a8ce6..6d041ed21d 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -258,29 +258,55 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee } func (ds *Datastore) ListGlobalPolicies(ctx context.Context) ([]*fleet.Policy, error) { - return listPoliciesDB(ctx, ds.reader, nil) + return listPoliciesDB(ctx, ds.reader, nil, nil) } -func listPoliciesDB(ctx context.Context, q sqlx.QueryerContext, teamID *uint) ([]*fleet.Policy, error) { - teamWhere := "p.team_id is NULL" +// returns the list of policies associated with the provided teamID, or the +// global policies if teamID is nil. The pass/fail host counts are the totals +// regardless of hosts' team if countsForTeamID is nil, or the totals just for +// hosts that belong to the provided countsForTeamID if it is not nil. +func listPoliciesDB(ctx context.Context, q sqlx.QueryerContext, teamID, countsForTeamID *uint) ([]*fleet.Policy, error) { var args []interface{} + + counts := ` + (select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count, + (select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count +` + if countsForTeamID != nil { + counts = ` + (select count(*) from policy_membership pm inner join hosts h on pm.host_id = h.id where pm.policy_id=p.id and pm.passes=true and h.team_id = ?) as passing_host_count, + (select count(*) from policy_membership pm inner join hosts h on pm.host_id = h.id where pm.policy_id=p.id and pm.passes=false and h.team_id = ?) as failing_host_count +` + args = append(args, *countsForTeamID, *countsForTeamID) + } + + teamWhere := "p.team_id is NULL" if teamID != nil { teamWhere = "p.team_id = ?" args = append(args, *teamID) } + var policies []*fleet.Policy err := sqlx.SelectContext( ctx, q, &policies, - fmt.Sprintf(`SELECT p.*, - COALESCE(u.name, '') AS author_name, - COALESCE(u.email, '') AS author_email, - (select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count, - (select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count - FROM policies p - LEFT JOIN users u ON p.author_id = u.id - WHERE %s`, teamWhere), args..., + fmt.Sprintf(`SELECT p.id, + p.team_id, + p.resolution, + p.name, + p.query, + p.description, + p.author_id, + p.platforms, + p.created_at, + p.updated_at, + COALESCE(u.name, '') AS author_name, + COALESCE(u.email, '') AS author_email, + %s + FROM policies p + LEFT JOIN users u ON p.author_id = u.id + WHERE %s`, counts, teamWhere), args..., ) if err != nil { return nil, ctxerr.Wrap(ctx, err, "listing policies") @@ -289,14 +315,23 @@ func listPoliciesDB(ctx context.Context, q sqlx.QueryerContext, teamID *uint) ([ } func (ds *Datastore) PoliciesByID(ctx context.Context, ids []uint) (map[uint]*fleet.Policy, error) { - sql := `SELECT p.*, - COALESCE(u.name, '') AS author_name, - COALESCE(u.email, '') AS author_email, - (select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count, - (select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count - FROM policies p - LEFT JOIN users u ON p.author_id = u.id - WHERE p.id IN (?)` + sql := `SELECT p.id, + p.team_id, + p.resolution, + p.name, + p.query, + p.description, + p.author_id, + p.platforms, + p.created_at, + p.updated_at, + COALESCE(u.name, '') AS author_name, + COALESCE(u.email, '') AS author_email, + (select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count, + (select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count + FROM policies p + LEFT JOIN users u ON p.author_id = u.id + WHERE p.id IN (?)` query, args, err := sqlx.In(sql, ids) if err != nil { return nil, ctxerr.Wrap(ctx, err, "building query to get policies by ID") @@ -421,8 +456,17 @@ func (ds *Datastore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *u return policyDB(ctx, ds.writer, uint(lastIdInt64), &teamID) } -func (ds *Datastore) ListTeamPolicies(ctx context.Context, teamID uint) ([]*fleet.Policy, error) { - return listPoliciesDB(ctx, ds.reader, &teamID) +func (ds *Datastore) ListTeamPolicies(ctx context.Context, teamID uint) (teamPolicies, inheritedPolicies []*fleet.Policy, err error) { + teamPolicies, err = listPoliciesDB(ctx, ds.reader, &teamID, nil) + if err != nil { + return nil, nil, err + } + // get inherited (global) policies with counts of hosts for that team + inheritedPolicies, err = listPoliciesDB(ctx, ds.reader, nil, &teamID) + if err != nil { + return nil, nil, err + } + return teamPolicies, inheritedPolicies, err } func (ds *Datastore) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) { diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 31c9904a1b..a0d552f15c 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -192,8 +192,10 @@ func testPoliciesNewGlobalPolicyProprietary(t *testing.T, ds *Datastore) { } func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) { + ctx := context.Background() + user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true) - host1, err := ds.NewHost(context.Background(), &fleet.Host{ + host1, err := ds.NewHost(ctx, &fleet.Host{ OsqueryHostID: "1234", DetailUpdatedAt: time.Now(), LabelUpdatedAt: time.Now(), @@ -205,7 +207,7 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) { }) require.NoError(t, err) - host2, err := ds.NewHost(context.Background(), &fleet.Host{ + host2, err := ds.NewHost(ctx, &fleet.Host{ OsqueryHostID: "5679", DetailUpdatedAt: time.Now(), LabelUpdatedAt: time.Now(), @@ -217,14 +219,14 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) { }) require.NoError(t, err) - q, err := ds.NewQuery(context.Background(), &fleet.Query{ + q, err := ds.NewQuery(ctx, &fleet.Query{ Name: "query1", Description: "query1 desc", Query: "select 1;", Saved: true, }) require.NoError(t, err) - p, err := ds.NewGlobalPolicy(context.Background(), &user1.ID, fleet.PolicyPayload{ + p, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{ QueryID: &q.ID, }) require.NoError(t, err) @@ -235,14 +237,14 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) { require.NotNil(t, p.AuthorID) assert.Equal(t, user1.ID, *p.AuthorID) - q2, err := ds.NewQuery(context.Background(), &fleet.Query{ + q2, err := ds.NewQuery(ctx, &fleet.Query{ Name: "query2", Description: "query2 desc", Query: "select 42;", Saved: true, }) require.NoError(t, err) - p2, err := ds.NewGlobalPolicy(context.Background(), &user1.ID, fleet.PolicyPayload{ + p2, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{ QueryID: &q2.ID, }) require.NoError(t, err) @@ -253,55 +255,128 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) { require.NotNil(t, p2.AuthorID) assert.Equal(t, user1.ID, *p2.AuthorID) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p.ID: nil}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: nil}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p2.ID: nil}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: nil}, time.Now(), deferred)) - policies, err := ds.ListGlobalPolicies(context.Background()) + policies, err := ds.ListGlobalPolicies(ctx) require.NoError(t, err) require.Len(t, policies, 2) + assert.Equal(t, p.ID, policies[0].ID) assert.Equal(t, uint(2), policies[0].PassingHostCount) assert.Equal(t, uint(0), policies[0].FailingHostCount) + assert.Equal(t, p2.ID, policies[1].ID) assert.Equal(t, uint(0), policies[1].PassingHostCount) assert.Equal(t, uint(0), policies[1].FailingHostCount) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p2.ID: ptr.Bool(false)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: ptr.Bool(false)}, time.Now(), deferred)) - policies, err = ds.ListGlobalPolicies(context.Background()) + policies, err = ds.ListGlobalPolicies(ctx) require.NoError(t, err) require.Len(t, policies, 2) + assert.Equal(t, p.ID, policies[0].ID) assert.Equal(t, uint(1), policies[0].PassingHostCount) assert.Equal(t, uint(1), policies[0].FailingHostCount) + assert.Equal(t, p2.ID, policies[1].ID) assert.Equal(t, uint(0), policies[1].PassingHostCount) assert.Equal(t, uint(1), policies[1].FailingHostCount) - policy, err := ds.Policy(context.Background(), policies[0].ID) + policy, err := ds.Policy(ctx, policies[0].ID) require.NoError(t, err) assert.Equal(t, policies[0], policy) - queries, err := ds.PolicyQueriesForHost(context.Background(), host1) + queries, err := ds.PolicyQueriesForHost(ctx, host1) require.NoError(t, err) require.Len(t, queries, 2) assert.Equal(t, q.Query, queries[fmt.Sprint(q.ID)]) assert.Equal(t, q2.Query, queries[fmt.Sprint(q2.ID)]) + + // create a couple teams and team-specific policies + team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"}) + require.NoError(t, err) + team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"}) + require.NoError(t, err) + + t1pol, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{ + Name: "team1pol", + Query: "SELECT 1", + }) + require.NoError(t, err) + t2pol, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{ + Name: "team2pol", + Query: "SELECT 2", + }) + require.NoError(t, err) + t2pol2, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{ + Name: "team2pol2", + Query: "SELECT 3", + }) + require.NoError(t, err) + + // create hosts in each team + host3, err := ds.EnrollHost(ctx, "3", "3", &team1.ID, 0) + require.NoError(t, err) + host4, err := ds.EnrollHost(ctx, "4", "4", &team2.ID, 0) + require.NoError(t, err) + host5, err := ds.EnrollHost(ctx, "5", "5", &team2.ID, 0) + require.NoError(t, err) + + // create some policy results + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{t1pol.ID: ptr.Bool(true), p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{t2pol.ID: ptr.Bool(false), t2pol2.ID: ptr.Bool(true), p.ID: ptr.Bool(false)}, time.Now(), deferred)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{t2pol.ID: ptr.Bool(true), t2pol2.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), deferred)) + + t1Pols, t1Inherited, err := ds.ListTeamPolicies(ctx, team1.ID) + require.NoError(t, err) + require.Len(t, t1Pols, 1) + assert.Equal(t, uint(1), t1Pols[0].PassingHostCount) + assert.Equal(t, uint(0), t1Pols[0].FailingHostCount) + + require.Len(t, t1Inherited, 2) + require.Equal(t, p.ID, t1Inherited[0].ID) + assert.Equal(t, uint(1), t1Inherited[0].PassingHostCount) + assert.Equal(t, uint(0), t1Inherited[0].FailingHostCount) + require.Equal(t, p2.ID, t1Inherited[1].ID) + assert.Equal(t, uint(0), t1Inherited[1].PassingHostCount) + assert.Equal(t, uint(1), t1Inherited[1].FailingHostCount) + + t2Pols, t2Inherited, err := ds.ListTeamPolicies(ctx, team2.ID) + require.NoError(t, err) + require.Len(t, t2Pols, 2) + require.Equal(t, t2pol.ID, t2Pols[0].ID) + assert.Equal(t, uint(1), t2Pols[0].PassingHostCount) + assert.Equal(t, uint(1), t2Pols[0].FailingHostCount) + require.Equal(t, t2pol2.ID, t2Pols[1].ID) + assert.Equal(t, uint(2), t2Pols[1].PassingHostCount) + assert.Equal(t, uint(0), t2Pols[1].FailingHostCount) + + require.Len(t, t2Inherited, 2) + require.Equal(t, p.ID, t2Inherited[0].ID) + assert.Equal(t, uint(0), t2Inherited[0].PassingHostCount) + assert.Equal(t, uint(1), t2Inherited[0].FailingHostCount) + require.Equal(t, p2.ID, t2Inherited[1].ID) + assert.Equal(t, uint(1), t2Inherited[1].PassingHostCount) + assert.Equal(t, uint(0), t2Inherited[1].FailingHostCount) } func testTeamPolicyLegacy(t *testing.T, ds *Datastore) { + ctx := context.Background() + user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true) - team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"}) + team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"}) require.NoError(t, err) - q, err := ds.NewQuery(context.Background(), &fleet.Query{ + q, err := ds.NewQuery(ctx, &fleet.Query{ Name: "query1", Description: "query1 desc", Query: "select 1;", @@ -309,10 +384,10 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) { }) require.NoError(t, err) - team2, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team2"}) + team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"}) require.NoError(t, err) - q2, err := ds.NewQuery(context.Background(), &fleet.Query{ + q2, err := ds.NewQuery(ctx, &fleet.Query{ Name: "query2", Description: "query2 desc", Query: "select 1;", @@ -320,15 +395,16 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) { }) require.NoError(t, err) - prevPolicies, err := ds.ListGlobalPolicies(context.Background()) + prevPolicies, err := ds.ListGlobalPolicies(ctx) require.NoError(t, err) + require.Len(t, prevPolicies, 0) - _, err = ds.NewTeamPolicy(context.Background(), 99999999, &user1.ID, fleet.PolicyPayload{ + _, err = ds.NewTeamPolicy(ctx, 99999999, &user1.ID, fleet.PolicyPayload{ QueryID: &q.ID, }) require.Error(t, err) - p, err := ds.NewTeamPolicy(context.Background(), team1.ID, &user1.ID, fleet.PolicyPayload{ + p, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{ QueryID: &q.ID, Resolution: "some resolution", }) @@ -343,11 +419,17 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) { require.NotNil(t, p.Resolution) assert.Equal(t, "some resolution", *p.Resolution) - globalPolicies, err := ds.ListGlobalPolicies(context.Background()) + gpol, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{ + Name: "global_1", + Query: "SELECT 1", + }) require.NoError(t, err) - require.Len(t, globalPolicies, len(prevPolicies)) - p2, err := ds.NewTeamPolicy(context.Background(), team2.ID, &user1.ID, fleet.PolicyPayload{ + globalPolicies, err := ds.ListGlobalPolicies(ctx) + require.NoError(t, err) + require.Len(t, globalPolicies, 1) + + p2, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{ QueryID: &q2.ID, }) require.NoError(t, err) @@ -358,7 +440,7 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) { require.NotNil(t, p2.AuthorID) assert.Equal(t, user1.ID, *p2.AuthorID) - teamPolicies, err := ds.ListTeamPolicies(context.Background(), team1.ID) + teamPolicies, inherited1, err := ds.ListTeamPolicies(ctx, team1.ID) require.NoError(t, err) require.Len(t, teamPolicies, 1) assert.Equal(t, q.Name, teamPolicies[0].Name) @@ -367,7 +449,10 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) { require.NotNil(t, teamPolicies[0].AuthorID) require.Equal(t, user1.ID, *teamPolicies[0].AuthorID) - team2Policies, err := ds.ListTeamPolicies(context.Background(), team2.ID) + require.Len(t, inherited1, 1) + require.Equal(t, gpol, inherited1[0]) + + team2Policies, inherited2, err := ds.ListTeamPolicies(ctx, team2.ID) require.NoError(t, err) require.Len(t, team2Policies, 1) assert.Equal(t, q2.Name, team2Policies[0].Name) @@ -376,12 +461,16 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) { require.NotNil(t, team2Policies[0].AuthorID) require.Equal(t, user1.ID, *team2Policies[0].AuthorID) - _, err = ds.DeleteTeamPolicies(context.Background(), team1.ID, []uint{teamPolicies[0].ID}) + require.Len(t, inherited2, 1) + require.Equal(t, gpol, inherited2[0]) + + _, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{teamPolicies[0].ID}) require.NoError(t, err) - teamPolicies, err = ds.ListTeamPolicies(context.Background(), team1.ID) + teamPolicies, inherited1, err = ds.ListTeamPolicies(ctx, team1.ID) require.NoError(t, err) require.Len(t, teamPolicies, 0) + require.Len(t, inherited1, 1) } func testTeamPolicyProprietary(t *testing.T, ds *Datastore) { @@ -392,7 +481,7 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) { require.NoError(t, err) ctx := context.Background() - _, err = ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{ + gpol, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{ Name: "existing-query-global-1", Query: "select 1;", Description: "query1 desc", @@ -402,7 +491,9 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) { prevPolicies, err := ds.ListGlobalPolicies(ctx) require.NoError(t, err) + require.Len(t, prevPolicies, 1) + // team does not exist _, err = ds.NewTeamPolicy(ctx, 99999999, &user1.ID, fleet.PolicyPayload{ Name: "query1", Query: "select 1;", @@ -429,6 +520,7 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) { IsExists() bool } require.True(t, errors.As(err, &isExist) && isExist.IsExists(), err) + // Can't create a global policy with an existing name. _, err = ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{ Name: "query1", @@ -436,6 +528,7 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) { }) require.Error(t, err) require.True(t, errors.As(err, &isExist) && isExist.IsExists(), err) + // Can't create a team policy with an existing global name. _, err = ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{ Name: "existing-query-global-1", @@ -472,7 +565,7 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) { require.NotNil(t, p2.AuthorID) assert.Equal(t, user1.ID, *p2.AuthorID) - teamPolicies, err := ds.ListTeamPolicies(ctx, team1.ID) + teamPolicies, inherited1, err := ds.ListTeamPolicies(ctx, team1.ID) require.NoError(t, err) require.Len(t, teamPolicies, 1) assert.Equal(t, "query1", teamPolicies[0].Name) @@ -483,7 +576,10 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) { require.NotNil(t, teamPolicies[0].AuthorID) require.Equal(t, user1.ID, *teamPolicies[0].AuthorID) - team2Policies, err := ds.ListTeamPolicies(context.Background(), team2.ID) + require.Len(t, inherited1, 1) + require.Equal(t, gpol, inherited1[0]) + + team2Policies, inherited2, err := ds.ListTeamPolicies(ctx, team2.ID) require.NoError(t, err) require.Len(t, team2Policies, 1) assert.Equal(t, "query2", team2Policies[0].Name) @@ -494,6 +590,9 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) { require.NotNil(t, team2Policies[0].AuthorID) require.Equal(t, user1.ID, *team2Policies[0].AuthorID) + require.Len(t, inherited2, 1) + require.Equal(t, gpol, inherited2[0]) + // Can't create a policy with the same name on the same team. p3, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{ Name: "query1", @@ -504,11 +603,14 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) { require.Error(t, err) require.Nil(t, p3) - _, err = ds.DeleteTeamPolicies(context.Background(), team1.ID, []uint{teamPolicies[0].ID}) + _, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{teamPolicies[0].ID}) require.NoError(t, err) - teamPolicies, err = ds.ListTeamPolicies(ctx, team1.ID) + + teamPolicies, inherited1, err = ds.ListTeamPolicies(ctx, team1.ID) require.NoError(t, err) require.Len(t, teamPolicies, 0) + require.Len(t, inherited1, 1) + require.Equal(t, gpol, inherited1[0]) // Now the name is available and we can create the policy in the team. _, err = ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{ @@ -518,7 +620,8 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) { Resolution: "query2 other resolution", }) require.NoError(t, err) - teamPolicies, err = ds.ListTeamPolicies(ctx, team1.ID) + + teamPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID) require.NoError(t, err) require.Len(t, teamPolicies, 1) assert.Equal(t, "query1", teamPolicies[0].Name) @@ -923,14 +1026,15 @@ func testPoliciesByID(t *testing.T, ds *Datastore) { } func testTeamPolicyTransfer(t *testing.T, ds *Datastore) { + ctx := context.Background() user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true) - team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: t.Name() + "team1"}) + team1, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team1"}) require.NoError(t, err) - team2, err := ds.NewTeam(context.Background(), &fleet.Team{Name: t.Name() + "team2"}) + team2, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team2"}) require.NoError(t, err) - host1, err := ds.NewHost(context.Background(), &fleet.Host{ + host1, err := ds.NewHost(ctx, &fleet.Host{ OsqueryHostID: "1234", DetailUpdatedAt: time.Now(), LabelUpdatedAt: time.Now(), @@ -941,83 +1045,89 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) { Hostname: "foo.local", }) require.NoError(t, err) - host2, err := ds.EnrollHost(context.Background(), "2", "2", &team1.ID, 0) + host2, err := ds.EnrollHost(ctx, "2", "2", &team1.ID, 0) require.NoError(t, err) - require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID})) - host1, err = ds.Host(context.Background(), host1.ID) + require.NoError(t, ds.AddHostsToTeam(ctx, &team1.ID, []uint{host1.ID})) + host1, err = ds.Host(ctx, host1.ID) require.NoError(t, err) - tq, err := ds.NewQuery(context.Background(), &fleet.Query{ + tq, err := ds.NewQuery(ctx, &fleet.Query{ Name: "query1", Description: "query1 desc", Query: "select 1;", Saved: true, }) require.NoError(t, err) - teamPolicy, err := ds.NewTeamPolicy(context.Background(), team1.ID, &user1.ID, fleet.PolicyPayload{ + team1Policy, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{ QueryID: &tq.ID, }) require.NoError(t, err) - gq, err := ds.NewQuery(context.Background(), &fleet.Query{ + gq, err := ds.NewQuery(ctx, &fleet.Query{ Name: "query2", Description: "query2 desc", Query: "select 2;", Saved: true, }) require.NoError(t, err) - globalPolicy, err := ds.NewGlobalPolicy(context.Background(), &user1.ID, fleet.PolicyPayload{ + globalPolicy, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{ QueryID: &gq.ID, }) require.NoError(t, err) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) - checkPassingCount := func(expectedCount, expectedGlobalCount uint) { - policies, err := ds.ListTeamPolicies(context.Background(), team1.ID) + checkPassingCount := func(tm1, tm1Inherited, tm2Inherited, global uint) { + policies, inherited, err := ds.ListTeamPolicies(ctx, team1.ID) require.NoError(t, err) require.Len(t, policies, 1) + assert.Equal(t, tm1, policies[0].PassingHostCount) + require.Len(t, inherited, 1) + assert.Equal(t, tm1Inherited, inherited[0].PassingHostCount) - assert.Equal(t, expectedCount, policies[0].PassingHostCount) + policies, inherited, err = ds.ListTeamPolicies(ctx, team2.ID) + require.NoError(t, err) + require.Len(t, policies, 0) // team 2 has no policies of its own + require.Len(t, inherited, 1) + assert.Equal(t, tm2Inherited, inherited[0].PassingHostCount) - policies, err = ds.ListGlobalPolicies(context.Background()) + policies, err = ds.ListGlobalPolicies(ctx) require.NoError(t, err) require.Len(t, policies, 1) - assert.Equal(t, expectedGlobalCount, policies[0].PassingHostCount) - - policies, err = ds.ListTeamPolicies(context.Background(), team2.ID) - require.NoError(t, err) - require.Len(t, policies, 0) + assert.Equal(t, global, policies[0].PassingHostCount) } - checkPassingCount(2, 2) + // both hosts belong to team1 and pass the team and the global policy + checkPassingCount(2, 2, 0, 2) // team policies are removed when AddHostsToTeam is called - require.NoError(t, ds.AddHostsToTeam(context.Background(), ptr.Uint(team2.ID), []uint{host1.ID})) - checkPassingCount(1, 2) + require.NoError(t, ds.AddHostsToTeam(ctx, ptr.Uint(team2.ID), []uint{host1.ID})) + // host2 passes tm1 and the global (so team1's inherited too), host1 passes the team2's inherited and the global + checkPassingCount(1, 1, 1, 2) // all host policies are removed when a host is enrolled in the same team - _, err = ds.EnrollHost(context.Background(), "2", "2", &team1.ID, 0) + _, err = ds.EnrollHost(ctx, "2", "2", &team1.ID, 0) require.NoError(t, err) - checkPassingCount(0, 1) + checkPassingCount(0, 0, 1, 1) // team policies are removed if the host is enrolled in a different team - _, err = ds.EnrollHost(context.Background(), "2", "2", &team2.ID, 0) + _, err = ds.EnrollHost(ctx, "2", "2", &team2.ID, 0) require.NoError(t, err) - checkPassingCount(0, 1) + // both hosts are now in team2 + checkPassingCount(0, 0, 1, 1) // team policies are removed if the host is re-enrolled without a team - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) - checkPassingCount(1, 2) + require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false)) + checkPassingCount(1, 0, 2, 2) // all host policies are removed when a host is re-enrolled - _, err = ds.EnrollHost(context.Background(), "2", "2", nil, 0) + _, err = ds.EnrollHost(ctx, "2", "2", nil, 0) require.NoError(t, err) - checkPassingCount(0, 1) + checkPassingCount(0, 0, 1, 1) } func testApplyPolicySpec(t *testing.T, ds *Datastore) { @@ -1065,7 +1175,7 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { assert.Equal(t, "some resolution", *policies[0].Resolution) assert.Equal(t, "", policies[0].Platform) - teamPolicies, err := ds.ListTeamPolicies(ctx, team1.ID) + teamPolicies, _, err := ds.ListTeamPolicies(ctx, team1.ID) require.NoError(t, err) require.Len(t, teamPolicies, 2) assert.Equal(t, "query2", teamPolicies[0].Name) @@ -1117,7 +1227,7 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { policies, err = ds.ListGlobalPolicies(ctx) require.NoError(t, err) require.Len(t, policies, 1) - teamPolicies, err = ds.ListTeamPolicies(ctx, team1.ID) + teamPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID) require.NoError(t, err) require.Len(t, teamPolicies, 2) @@ -1153,7 +1263,7 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) { assert.Equal(t, "some resolution updated", *policies[0].Resolution) assert.Equal(t, "", policies[0].Platform) - teamPolicies, err = ds.ListTeamPolicies(ctx, team1.ID) + teamPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID) require.NoError(t, err) require.Len(t, teamPolicies, 2) @@ -1529,7 +1639,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) { require.NoError(t, err) require.Len(t, gpols, 2) // load the team policies - tpols, err := ds.ListTeamPolicies(ctx, tm.ID) + tpols, _, err := ds.ListTeamPolicies(ctx, tm.ID) require.NoError(t, err) require.Len(t, tpols, 2) diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index d00d86c76c..0e0b5d710f 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -478,7 +478,7 @@ type Datastore interface { // Team Policies NewTeamPolicy(ctx context.Context, teamID uint, authorID *uint, args PolicyPayload) (*Policy, error) - ListTeamPolicies(ctx context.Context, teamID uint) ([]*Policy, error) + ListTeamPolicies(ctx context.Context, teamID uint) (teamPolicies, inheritedPolicies []*Policy, err error) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) TeamPolicy(ctx context.Context, teamID uint, policyID uint) (*Policy, error) diff --git a/server/fleet/service.go b/server/fleet/service.go index 5178e7a037..2dfa257b02 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -507,7 +507,7 @@ type Service interface { // Team Policies NewTeamPolicy(ctx context.Context, teamID uint, p PolicyPayload) (*Policy, error) - ListTeamPolicies(ctx context.Context, teamID uint) ([]*Policy, error) + ListTeamPolicies(ctx context.Context, teamID uint) (teamPolicies, inheritedPolicies []*Policy, err error) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) ModifyTeamPolicy(ctx context.Context, teamID uint, id uint, p ModifyPolicyPayload) (*Policy, error) GetTeamPolicyByIDQueries(ctx context.Context, teamID uint, policyID uint) (*Policy, error) diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 106da8236d..bd9f05d45a 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -373,7 +373,7 @@ type DeleteSoftwareVulnerabilitiesFunc func(ctx context.Context, vulnerabilities type NewTeamPolicyFunc func(ctx context.Context, teamID uint, authorID *uint, args fleet.PolicyPayload) (*fleet.Policy, error) -type ListTeamPoliciesFunc func(ctx context.Context, teamID uint) ([]*fleet.Policy, error) +type ListTeamPoliciesFunc func(ctx context.Context, teamID uint) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) type DeleteTeamPoliciesFunc func(ctx context.Context, teamID uint, ids []uint) ([]uint, error) @@ -2075,7 +2075,7 @@ func (s *DataStore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *ui return s.NewTeamPolicyFunc(ctx, teamID, authorID, args) } -func (s *DataStore) ListTeamPolicies(ctx context.Context, teamID uint) ([]*fleet.Policy, error) { +func (s *DataStore) ListTeamPolicies(ctx context.Context, teamID uint) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) { s.ListTeamPoliciesFuncInvoked = true return s.ListTeamPoliciesFunc(ctx, teamID) } diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index a93d016f40..224c316375 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -1801,6 +1801,7 @@ func (s *integrationTestSuite) TestTeamPoliciesProprietary() { require.NotNil(t, policiesResponse.Policies[0].Resolution) assert.Equal(t, "some team resolution updated", *policiesResponse.Policies[0].Resolution) assert.Equal(t, "darwin", policiesResponse.Policies[0].Platform) + require.Len(t, policiesResponse.InheritedPolicies, 0) listHostsURL := fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d", policiesResponse.Policies[0].ID) listHostsResp := listHostsResponse{} diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index a689936216..b76b658222 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -284,6 +284,15 @@ func (s *integrationEnterpriseTestSuite) TestTeamPolicies() { ts := listTeamPoliciesResponse{} s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/teams/%d/policies", team1.ID), nil, http.StatusOK, &ts) require.Len(t, ts.Policies, 0) + require.Len(t, ts.InheritedPolicies, 0) + + // create a global policy + gpol, err := s.ds.NewGlobalPolicy(context.Background(), nil, fleet.PolicyPayload{Name: "TestGlobalPolicy", Query: "SELECT 1"}) + require.NoError(t, err) + defer func() { + _, err := s.ds.DeleteGlobalPolicies(context.Background(), []uint{gpol.ID}) + require.NoError(t, err) + }() qr, err := s.ds.NewQuery(context.Background(), &fleet.Query{Name: "TestQuery2", Description: "Some description", Query: "select * from osquery;", ObserverCanRun: true}) require.NoError(t, err) @@ -303,6 +312,9 @@ func (s *integrationEnterpriseTestSuite) TestTeamPolicies() { assert.Equal(t, "Some description", ts.Policies[0].Description) require.NotNil(t, ts.Policies[0].Resolution) assert.Equal(t, "some team resolution", *ts.Policies[0].Resolution) + require.Len(t, ts.InheritedPolicies, 1) + assert.Equal(t, gpol.Name, ts.InheritedPolicies[0].Name) + assert.Equal(t, gpol.ID, ts.InheritedPolicies[0].ID) deletePolicyParams := deleteTeamPoliciesRequest{IDs: []uint{ts.Policies[0].ID}} deletePolicyResp := deleteTeamPoliciesResponse{} diff --git a/server/service/team_policies.go b/server/service/team_policies.go index 826a004021..a039691a4c 100644 --- a/server/service/team_policies.go +++ b/server/service/team_policies.go @@ -96,32 +96,33 @@ type listTeamPoliciesRequest struct { } type listTeamPoliciesResponse struct { - Policies []*fleet.Policy `json:"policies,omitempty"` - Err error `json:"error,omitempty"` + Policies []*fleet.Policy `json:"policies,omitempty"` + InheritedPolicies []*fleet.Policy `json:"inherited_policies,omitempty"` + Err error `json:"error,omitempty"` } func (r listTeamPoliciesResponse) error() error { return r.Err } func listTeamPoliciesEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) { req := request.(*listTeamPoliciesRequest) - resp, err := svc.ListTeamPolicies(ctx, req.TeamID) + tmPols, inheritedPols, err := svc.ListTeamPolicies(ctx, req.TeamID) if err != nil { return listTeamPoliciesResponse{Err: err}, nil } - return listTeamPoliciesResponse{Policies: resp}, nil + return listTeamPoliciesResponse{Policies: tmPols, InheritedPolicies: inheritedPols}, nil } -func (svc Service) ListTeamPolicies(ctx context.Context, teamID uint) ([]*fleet.Policy, error) { +func (svc *Service) ListTeamPolicies(ctx context.Context, teamID uint) (teamPolicies, inheritedPolicies []*fleet.Policy, err error) { if err := svc.authz.Authorize(ctx, &fleet.Policy{ PolicyData: fleet.PolicyData{ TeamID: ptr.Uint(teamID), }, }, fleet.ActionRead); err != nil { - return nil, err + return nil, nil, err } if _, err := svc.ds.Team(ctx, teamID); err != nil { - return nil, ctxerr.Wrapf(ctx, err, "loading team %d", teamID) + return nil, nil, ctxerr.Wrapf(ctx, err, "loading team %d", teamID) } return svc.ds.ListTeamPolicies(ctx, teamID) diff --git a/server/service/team_policies_test.go b/server/service/team_policies_test.go index ac65c685c7..4785cce921 100644 --- a/server/service/team_policies_test.go +++ b/server/service/team_policies_test.go @@ -24,8 +24,8 @@ func TestTeamPoliciesAuth(t *testing.T) { }, }, nil } - ds.ListTeamPoliciesFunc = func(ctx context.Context, teamID uint) ([]*fleet.Policy, error) { - return nil, nil + ds.ListTeamPoliciesFunc = func(ctx context.Context, teamID uint) (tpol, ipol []*fleet.Policy, err error) { + return nil, nil, nil } ds.PoliciesByIDFunc = func(ctx context.Context, ids []uint) (map[uint]*fleet.Policy, error) { return nil, nil @@ -149,7 +149,7 @@ func TestTeamPoliciesAuth(t *testing.T) { }) checkAuthErr(t, tt.shouldFailWrite, err) - _, err = svc.ListTeamPolicies(ctx, 1) + _, _, err = svc.ListTeamPolicies(ctx, 1) checkAuthErr(t, tt.shouldFailRead, err) _, err = svc.GetTeamPolicyByIDQueries(ctx, 1, 1)