Add inherited policies to the team's list policies response payload (#8068)

This commit is contained in:
Martin Angers 2022-10-12 08:35:36 -04:00 committed by GitHub
parent 42c47a6fa7
commit d321cfc68e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 298 additions and 111 deletions

View file

@ -0,0 +1 @@
* Added the `inherited_policies` array to the `GET /teams/{team_id}/policies` endpoint that lists the global policies inherited by the team, along with the pass/fail counts only for hosts that belong to that team.

View file

@ -3506,6 +3506,24 @@ Team policies work the same as policies, but at the team level.
"passing_host_count": 2300,
"failing_host_count": 0
}
],
"inherited_policies": [
{
"id": 136,
"name": "Arbitrary Test Policy (all platforms) (all teams)",
"query": "SELECT 1 FROM osquery_info WHERE 1=1;",
"description": "If you're seeing this, mostly likely this is because someone is testing out failing policies in dogfood. You can ignore this.",
"author_id": 77,
"author_name": "Test Admin",
"author_email": "test@admin.com",
"team_id": null,
"resolution": "To make it pass, change \"1=0\" to \"1=1\". To make it fail, change \"1=1\" to \"1=0\".",
"platform": "darwin,windows,linux",
"created_at": "2022-08-04T19:30:18Z",
"updated_at": "2022-08-30T15:08:26Z",
"passing_host_count": 10,
"failing_host_count": 9
}
]
}
```

View file

@ -258,29 +258,55 @@ func (ds *Datastore) RecordPolicyQueryExecutions(ctx context.Context, host *flee
}
func (ds *Datastore) ListGlobalPolicies(ctx context.Context) ([]*fleet.Policy, error) {
return listPoliciesDB(ctx, ds.reader, nil)
return listPoliciesDB(ctx, ds.reader, nil, nil)
}
func listPoliciesDB(ctx context.Context, q sqlx.QueryerContext, teamID *uint) ([]*fleet.Policy, error) {
teamWhere := "p.team_id is NULL"
// returns the list of policies associated with the provided teamID, or the
// global policies if teamID is nil. The pass/fail host counts are the totals
// regardless of hosts' team if countsForTeamID is nil, or the totals just for
// hosts that belong to the provided countsForTeamID if it is not nil.
func listPoliciesDB(ctx context.Context, q sqlx.QueryerContext, teamID, countsForTeamID *uint) ([]*fleet.Policy, error) {
var args []interface{}
counts := `
(select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count,
(select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count
`
if countsForTeamID != nil {
counts = `
(select count(*) from policy_membership pm inner join hosts h on pm.host_id = h.id where pm.policy_id=p.id and pm.passes=true and h.team_id = ?) as passing_host_count,
(select count(*) from policy_membership pm inner join hosts h on pm.host_id = h.id where pm.policy_id=p.id and pm.passes=false and h.team_id = ?) as failing_host_count
`
args = append(args, *countsForTeamID, *countsForTeamID)
}
teamWhere := "p.team_id is NULL"
if teamID != nil {
teamWhere = "p.team_id = ?"
args = append(args, *teamID)
}
var policies []*fleet.Policy
err := sqlx.SelectContext(
ctx,
q,
&policies,
fmt.Sprintf(`SELECT p.*,
COALESCE(u.name, '<deleted>') AS author_name,
COALESCE(u.email, '') AS author_email,
(select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count,
(select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count
FROM policies p
LEFT JOIN users u ON p.author_id = u.id
WHERE %s`, teamWhere), args...,
fmt.Sprintf(`SELECT p.id,
p.team_id,
p.resolution,
p.name,
p.query,
p.description,
p.author_id,
p.platforms,
p.created_at,
p.updated_at,
COALESCE(u.name, '<deleted>') AS author_name,
COALESCE(u.email, '') AS author_email,
%s
FROM policies p
LEFT JOIN users u ON p.author_id = u.id
WHERE %s`, counts, teamWhere), args...,
)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "listing policies")
@ -289,14 +315,23 @@ func listPoliciesDB(ctx context.Context, q sqlx.QueryerContext, teamID *uint) ([
}
func (ds *Datastore) PoliciesByID(ctx context.Context, ids []uint) (map[uint]*fleet.Policy, error) {
sql := `SELECT p.*,
COALESCE(u.name, '<deleted>') AS author_name,
COALESCE(u.email, '') AS author_email,
(select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count,
(select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count
FROM policies p
LEFT JOIN users u ON p.author_id = u.id
WHERE p.id IN (?)`
sql := `SELECT p.id,
p.team_id,
p.resolution,
p.name,
p.query,
p.description,
p.author_id,
p.platforms,
p.created_at,
p.updated_at,
COALESCE(u.name, '<deleted>') AS author_name,
COALESCE(u.email, '') AS author_email,
(select count(*) from policy_membership where policy_id=p.id and passes=true) as passing_host_count,
(select count(*) from policy_membership where policy_id=p.id and passes=false) as failing_host_count
FROM policies p
LEFT JOIN users u ON p.author_id = u.id
WHERE p.id IN (?)`
query, args, err := sqlx.In(sql, ids)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "building query to get policies by ID")
@ -421,8 +456,17 @@ func (ds *Datastore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *u
return policyDB(ctx, ds.writer, uint(lastIdInt64), &teamID)
}
func (ds *Datastore) ListTeamPolicies(ctx context.Context, teamID uint) ([]*fleet.Policy, error) {
return listPoliciesDB(ctx, ds.reader, &teamID)
func (ds *Datastore) ListTeamPolicies(ctx context.Context, teamID uint) (teamPolicies, inheritedPolicies []*fleet.Policy, err error) {
teamPolicies, err = listPoliciesDB(ctx, ds.reader, &teamID, nil)
if err != nil {
return nil, nil, err
}
// get inherited (global) policies with counts of hosts for that team
inheritedPolicies, err = listPoliciesDB(ctx, ds.reader, nil, &teamID)
if err != nil {
return nil, nil, err
}
return teamPolicies, inheritedPolicies, err
}
func (ds *Datastore) DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) {

View file

@ -192,8 +192,10 @@ func testPoliciesNewGlobalPolicyProprietary(t *testing.T, ds *Datastore) {
}
func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
ctx := context.Background()
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
host1, err := ds.NewHost(context.Background(), &fleet.Host{
host1, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: "1234",
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
@ -205,7 +207,7 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
})
require.NoError(t, err)
host2, err := ds.NewHost(context.Background(), &fleet.Host{
host2, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: "5679",
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
@ -217,14 +219,14 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
})
require.NoError(t, err)
q, err := ds.NewQuery(context.Background(), &fleet.Query{
q, err := ds.NewQuery(ctx, &fleet.Query{
Name: "query1",
Description: "query1 desc",
Query: "select 1;",
Saved: true,
})
require.NoError(t, err)
p, err := ds.NewGlobalPolicy(context.Background(), &user1.ID, fleet.PolicyPayload{
p, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
QueryID: &q.ID,
})
require.NoError(t, err)
@ -235,14 +237,14 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
require.NotNil(t, p.AuthorID)
assert.Equal(t, user1.ID, *p.AuthorID)
q2, err := ds.NewQuery(context.Background(), &fleet.Query{
q2, err := ds.NewQuery(ctx, &fleet.Query{
Name: "query2",
Description: "query2 desc",
Query: "select 42;",
Saved: true,
})
require.NoError(t, err)
p2, err := ds.NewGlobalPolicy(context.Background(), &user1.ID, fleet.PolicyPayload{
p2, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
QueryID: &q2.ID,
})
require.NoError(t, err)
@ -253,55 +255,128 @@ func testPoliciesMembershipView(deferred bool, t *testing.T, ds *Datastore) {
require.NotNil(t, p2.AuthorID)
assert.Equal(t, user1.ID, *p2.AuthorID)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p.ID: nil}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: nil}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p2.ID: nil}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: nil}, time.Now(), deferred))
policies, err := ds.ListGlobalPolicies(context.Background())
policies, err := ds.ListGlobalPolicies(ctx)
require.NoError(t, err)
require.Len(t, policies, 2)
assert.Equal(t, p.ID, policies[0].ID)
assert.Equal(t, uint(2), policies[0].PassingHostCount)
assert.Equal(t, uint(0), policies[0].FailingHostCount)
assert.Equal(t, p2.ID, policies[1].ID)
assert.Equal(t, uint(0), policies[1].PassingHostCount)
assert.Equal(t, uint(0), policies[1].FailingHostCount)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p2.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{p2.ID: ptr.Bool(false)}, time.Now(), deferred))
policies, err = ds.ListGlobalPolicies(context.Background())
policies, err = ds.ListGlobalPolicies(ctx)
require.NoError(t, err)
require.Len(t, policies, 2)
assert.Equal(t, p.ID, policies[0].ID)
assert.Equal(t, uint(1), policies[0].PassingHostCount)
assert.Equal(t, uint(1), policies[0].FailingHostCount)
assert.Equal(t, p2.ID, policies[1].ID)
assert.Equal(t, uint(0), policies[1].PassingHostCount)
assert.Equal(t, uint(1), policies[1].FailingHostCount)
policy, err := ds.Policy(context.Background(), policies[0].ID)
policy, err := ds.Policy(ctx, policies[0].ID)
require.NoError(t, err)
assert.Equal(t, policies[0], policy)
queries, err := ds.PolicyQueriesForHost(context.Background(), host1)
queries, err := ds.PolicyQueriesForHost(ctx, host1)
require.NoError(t, err)
require.Len(t, queries, 2)
assert.Equal(t, q.Query, queries[fmt.Sprint(q.ID)])
assert.Equal(t, q2.Query, queries[fmt.Sprint(q2.ID)])
// create a couple teams and team-specific policies
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
require.NoError(t, err)
team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
require.NoError(t, err)
t1pol, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "team1pol",
Query: "SELECT 1",
})
require.NoError(t, err)
t2pol, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{
Name: "team2pol",
Query: "SELECT 2",
})
require.NoError(t, err)
t2pol2, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{
Name: "team2pol2",
Query: "SELECT 3",
})
require.NoError(t, err)
// create hosts in each team
host3, err := ds.EnrollHost(ctx, "3", "3", &team1.ID, 0)
require.NoError(t, err)
host4, err := ds.EnrollHost(ctx, "4", "4", &team2.ID, 0)
require.NoError(t, err)
host5, err := ds.EnrollHost(ctx, "5", "5", &team2.ID, 0)
require.NoError(t, err)
// create some policy results
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host3, map[uint]*bool{t1pol.ID: ptr.Bool(true), p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host4, map[uint]*bool{t2pol.ID: ptr.Bool(false), t2pol2.ID: ptr.Bool(true), p.ID: ptr.Bool(false)}, time.Now(), deferred))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host5, map[uint]*bool{t2pol.ID: ptr.Bool(true), t2pol2.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), deferred))
t1Pols, t1Inherited, err := ds.ListTeamPolicies(ctx, team1.ID)
require.NoError(t, err)
require.Len(t, t1Pols, 1)
assert.Equal(t, uint(1), t1Pols[0].PassingHostCount)
assert.Equal(t, uint(0), t1Pols[0].FailingHostCount)
require.Len(t, t1Inherited, 2)
require.Equal(t, p.ID, t1Inherited[0].ID)
assert.Equal(t, uint(1), t1Inherited[0].PassingHostCount)
assert.Equal(t, uint(0), t1Inherited[0].FailingHostCount)
require.Equal(t, p2.ID, t1Inherited[1].ID)
assert.Equal(t, uint(0), t1Inherited[1].PassingHostCount)
assert.Equal(t, uint(1), t1Inherited[1].FailingHostCount)
t2Pols, t2Inherited, err := ds.ListTeamPolicies(ctx, team2.ID)
require.NoError(t, err)
require.Len(t, t2Pols, 2)
require.Equal(t, t2pol.ID, t2Pols[0].ID)
assert.Equal(t, uint(1), t2Pols[0].PassingHostCount)
assert.Equal(t, uint(1), t2Pols[0].FailingHostCount)
require.Equal(t, t2pol2.ID, t2Pols[1].ID)
assert.Equal(t, uint(2), t2Pols[1].PassingHostCount)
assert.Equal(t, uint(0), t2Pols[1].FailingHostCount)
require.Len(t, t2Inherited, 2)
require.Equal(t, p.ID, t2Inherited[0].ID)
assert.Equal(t, uint(0), t2Inherited[0].PassingHostCount)
assert.Equal(t, uint(1), t2Inherited[0].FailingHostCount)
require.Equal(t, p2.ID, t2Inherited[1].ID)
assert.Equal(t, uint(1), t2Inherited[1].PassingHostCount)
assert.Equal(t, uint(0), t2Inherited[1].FailingHostCount)
}
func testTeamPolicyLegacy(t *testing.T, ds *Datastore) {
ctx := context.Background()
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"})
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: "team1"})
require.NoError(t, err)
q, err := ds.NewQuery(context.Background(), &fleet.Query{
q, err := ds.NewQuery(ctx, &fleet.Query{
Name: "query1",
Description: "query1 desc",
Query: "select 1;",
@ -309,10 +384,10 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) {
})
require.NoError(t, err)
team2, err := ds.NewTeam(context.Background(), &fleet.Team{Name: "team2"})
team2, err := ds.NewTeam(ctx, &fleet.Team{Name: "team2"})
require.NoError(t, err)
q2, err := ds.NewQuery(context.Background(), &fleet.Query{
q2, err := ds.NewQuery(ctx, &fleet.Query{
Name: "query2",
Description: "query2 desc",
Query: "select 1;",
@ -320,15 +395,16 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) {
})
require.NoError(t, err)
prevPolicies, err := ds.ListGlobalPolicies(context.Background())
prevPolicies, err := ds.ListGlobalPolicies(ctx)
require.NoError(t, err)
require.Len(t, prevPolicies, 0)
_, err = ds.NewTeamPolicy(context.Background(), 99999999, &user1.ID, fleet.PolicyPayload{
_, err = ds.NewTeamPolicy(ctx, 99999999, &user1.ID, fleet.PolicyPayload{
QueryID: &q.ID,
})
require.Error(t, err)
p, err := ds.NewTeamPolicy(context.Background(), team1.ID, &user1.ID, fleet.PolicyPayload{
p, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
QueryID: &q.ID,
Resolution: "some resolution",
})
@ -343,11 +419,17 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) {
require.NotNil(t, p.Resolution)
assert.Equal(t, "some resolution", *p.Resolution)
globalPolicies, err := ds.ListGlobalPolicies(context.Background())
gpol, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "global_1",
Query: "SELECT 1",
})
require.NoError(t, err)
require.Len(t, globalPolicies, len(prevPolicies))
p2, err := ds.NewTeamPolicy(context.Background(), team2.ID, &user1.ID, fleet.PolicyPayload{
globalPolicies, err := ds.ListGlobalPolicies(ctx)
require.NoError(t, err)
require.Len(t, globalPolicies, 1)
p2, err := ds.NewTeamPolicy(ctx, team2.ID, &user1.ID, fleet.PolicyPayload{
QueryID: &q2.ID,
})
require.NoError(t, err)
@ -358,7 +440,7 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) {
require.NotNil(t, p2.AuthorID)
assert.Equal(t, user1.ID, *p2.AuthorID)
teamPolicies, err := ds.ListTeamPolicies(context.Background(), team1.ID)
teamPolicies, inherited1, err := ds.ListTeamPolicies(ctx, team1.ID)
require.NoError(t, err)
require.Len(t, teamPolicies, 1)
assert.Equal(t, q.Name, teamPolicies[0].Name)
@ -367,7 +449,10 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) {
require.NotNil(t, teamPolicies[0].AuthorID)
require.Equal(t, user1.ID, *teamPolicies[0].AuthorID)
team2Policies, err := ds.ListTeamPolicies(context.Background(), team2.ID)
require.Len(t, inherited1, 1)
require.Equal(t, gpol, inherited1[0])
team2Policies, inherited2, err := ds.ListTeamPolicies(ctx, team2.ID)
require.NoError(t, err)
require.Len(t, team2Policies, 1)
assert.Equal(t, q2.Name, team2Policies[0].Name)
@ -376,12 +461,16 @@ func testTeamPolicyLegacy(t *testing.T, ds *Datastore) {
require.NotNil(t, team2Policies[0].AuthorID)
require.Equal(t, user1.ID, *team2Policies[0].AuthorID)
_, err = ds.DeleteTeamPolicies(context.Background(), team1.ID, []uint{teamPolicies[0].ID})
require.Len(t, inherited2, 1)
require.Equal(t, gpol, inherited2[0])
_, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{teamPolicies[0].ID})
require.NoError(t, err)
teamPolicies, err = ds.ListTeamPolicies(context.Background(), team1.ID)
teamPolicies, inherited1, err = ds.ListTeamPolicies(ctx, team1.ID)
require.NoError(t, err)
require.Len(t, teamPolicies, 0)
require.Len(t, inherited1, 1)
}
func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
@ -392,7 +481,7 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
require.NoError(t, err)
ctx := context.Background()
_, err = ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
gpol, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "existing-query-global-1",
Query: "select 1;",
Description: "query1 desc",
@ -402,7 +491,9 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
prevPolicies, err := ds.ListGlobalPolicies(ctx)
require.NoError(t, err)
require.Len(t, prevPolicies, 1)
// team does not exist
_, err = ds.NewTeamPolicy(ctx, 99999999, &user1.ID, fleet.PolicyPayload{
Name: "query1",
Query: "select 1;",
@ -429,6 +520,7 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
IsExists() bool
}
require.True(t, errors.As(err, &isExist) && isExist.IsExists(), err)
// Can't create a global policy with an existing name.
_, err = ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
Name: "query1",
@ -436,6 +528,7 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
})
require.Error(t, err)
require.True(t, errors.As(err, &isExist) && isExist.IsExists(), err)
// Can't create a team policy with an existing global name.
_, err = ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "existing-query-global-1",
@ -472,7 +565,7 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
require.NotNil(t, p2.AuthorID)
assert.Equal(t, user1.ID, *p2.AuthorID)
teamPolicies, err := ds.ListTeamPolicies(ctx, team1.ID)
teamPolicies, inherited1, err := ds.ListTeamPolicies(ctx, team1.ID)
require.NoError(t, err)
require.Len(t, teamPolicies, 1)
assert.Equal(t, "query1", teamPolicies[0].Name)
@ -483,7 +576,10 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
require.NotNil(t, teamPolicies[0].AuthorID)
require.Equal(t, user1.ID, *teamPolicies[0].AuthorID)
team2Policies, err := ds.ListTeamPolicies(context.Background(), team2.ID)
require.Len(t, inherited1, 1)
require.Equal(t, gpol, inherited1[0])
team2Policies, inherited2, err := ds.ListTeamPolicies(ctx, team2.ID)
require.NoError(t, err)
require.Len(t, team2Policies, 1)
assert.Equal(t, "query2", team2Policies[0].Name)
@ -494,6 +590,9 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
require.NotNil(t, team2Policies[0].AuthorID)
require.Equal(t, user1.ID, *team2Policies[0].AuthorID)
require.Len(t, inherited2, 1)
require.Equal(t, gpol, inherited2[0])
// Can't create a policy with the same name on the same team.
p3, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
Name: "query1",
@ -504,11 +603,14 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
require.Error(t, err)
require.Nil(t, p3)
_, err = ds.DeleteTeamPolicies(context.Background(), team1.ID, []uint{teamPolicies[0].ID})
_, err = ds.DeleteTeamPolicies(ctx, team1.ID, []uint{teamPolicies[0].ID})
require.NoError(t, err)
teamPolicies, err = ds.ListTeamPolicies(ctx, team1.ID)
teamPolicies, inherited1, err = ds.ListTeamPolicies(ctx, team1.ID)
require.NoError(t, err)
require.Len(t, teamPolicies, 0)
require.Len(t, inherited1, 1)
require.Equal(t, gpol, inherited1[0])
// Now the name is available and we can create the policy in the team.
_, err = ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
@ -518,7 +620,8 @@ func testTeamPolicyProprietary(t *testing.T, ds *Datastore) {
Resolution: "query2 other resolution",
})
require.NoError(t, err)
teamPolicies, err = ds.ListTeamPolicies(ctx, team1.ID)
teamPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID)
require.NoError(t, err)
require.Len(t, teamPolicies, 1)
assert.Equal(t, "query1", teamPolicies[0].Name)
@ -923,14 +1026,15 @@ func testPoliciesByID(t *testing.T, ds *Datastore) {
}
func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
ctx := context.Background()
user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true)
team1, err := ds.NewTeam(context.Background(), &fleet.Team{Name: t.Name() + "team1"})
team1, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team1"})
require.NoError(t, err)
team2, err := ds.NewTeam(context.Background(), &fleet.Team{Name: t.Name() + "team2"})
team2, err := ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team2"})
require.NoError(t, err)
host1, err := ds.NewHost(context.Background(), &fleet.Host{
host1, err := ds.NewHost(ctx, &fleet.Host{
OsqueryHostID: "1234",
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
@ -941,83 +1045,89 @@ func testTeamPolicyTransfer(t *testing.T, ds *Datastore) {
Hostname: "foo.local",
})
require.NoError(t, err)
host2, err := ds.EnrollHost(context.Background(), "2", "2", &team1.ID, 0)
host2, err := ds.EnrollHost(ctx, "2", "2", &team1.ID, 0)
require.NoError(t, err)
require.NoError(t, ds.AddHostsToTeam(context.Background(), &team1.ID, []uint{host1.ID}))
host1, err = ds.Host(context.Background(), host1.ID)
require.NoError(t, ds.AddHostsToTeam(ctx, &team1.ID, []uint{host1.ID}))
host1, err = ds.Host(ctx, host1.ID)
require.NoError(t, err)
tq, err := ds.NewQuery(context.Background(), &fleet.Query{
tq, err := ds.NewQuery(ctx, &fleet.Query{
Name: "query1",
Description: "query1 desc",
Query: "select 1;",
Saved: true,
})
require.NoError(t, err)
teamPolicy, err := ds.NewTeamPolicy(context.Background(), team1.ID, &user1.ID, fleet.PolicyPayload{
team1Policy, err := ds.NewTeamPolicy(ctx, team1.ID, &user1.ID, fleet.PolicyPayload{
QueryID: &tq.ID,
})
require.NoError(t, err)
gq, err := ds.NewQuery(context.Background(), &fleet.Query{
gq, err := ds.NewQuery(ctx, &fleet.Query{
Name: "query2",
Description: "query2 desc",
Query: "select 2;",
Saved: true,
})
require.NoError(t, err)
globalPolicy, err := ds.NewGlobalPolicy(context.Background(), &user1.ID, fleet.PolicyPayload{
globalPolicy, err := ds.NewGlobalPolicy(ctx, &user1.ID, fleet.PolicyPayload{
QueryID: &gq.ID,
})
require.NoError(t, err)
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host1, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host1, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(false), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
checkPassingCount := func(expectedCount, expectedGlobalCount uint) {
policies, err := ds.ListTeamPolicies(context.Background(), team1.ID)
checkPassingCount := func(tm1, tm1Inherited, tm2Inherited, global uint) {
policies, inherited, err := ds.ListTeamPolicies(ctx, team1.ID)
require.NoError(t, err)
require.Len(t, policies, 1)
assert.Equal(t, tm1, policies[0].PassingHostCount)
require.Len(t, inherited, 1)
assert.Equal(t, tm1Inherited, inherited[0].PassingHostCount)
assert.Equal(t, expectedCount, policies[0].PassingHostCount)
policies, inherited, err = ds.ListTeamPolicies(ctx, team2.ID)
require.NoError(t, err)
require.Len(t, policies, 0) // team 2 has no policies of its own
require.Len(t, inherited, 1)
assert.Equal(t, tm2Inherited, inherited[0].PassingHostCount)
policies, err = ds.ListGlobalPolicies(context.Background())
policies, err = ds.ListGlobalPolicies(ctx)
require.NoError(t, err)
require.Len(t, policies, 1)
assert.Equal(t, expectedGlobalCount, policies[0].PassingHostCount)
policies, err = ds.ListTeamPolicies(context.Background(), team2.ID)
require.NoError(t, err)
require.Len(t, policies, 0)
assert.Equal(t, global, policies[0].PassingHostCount)
}
checkPassingCount(2, 2)
// both hosts belong to team1 and pass the team and the global policy
checkPassingCount(2, 2, 0, 2)
// team policies are removed when AddHostsToTeam is called
require.NoError(t, ds.AddHostsToTeam(context.Background(), ptr.Uint(team2.ID), []uint{host1.ID}))
checkPassingCount(1, 2)
require.NoError(t, ds.AddHostsToTeam(ctx, ptr.Uint(team2.ID), []uint{host1.ID}))
// host2 passes tm1 and the global (so team1's inherited too), host1 passes the team2's inherited and the global
checkPassingCount(1, 1, 1, 2)
// all host policies are removed when a host is enrolled in the same team
_, err = ds.EnrollHost(context.Background(), "2", "2", &team1.ID, 0)
_, err = ds.EnrollHost(ctx, "2", "2", &team1.ID, 0)
require.NoError(t, err)
checkPassingCount(0, 1)
checkPassingCount(0, 0, 1, 1)
// team policies are removed if the host is enrolled in a different team
_, err = ds.EnrollHost(context.Background(), "2", "2", &team2.ID, 0)
_, err = ds.EnrollHost(ctx, "2", "2", &team2.ID, 0)
require.NoError(t, err)
checkPassingCount(0, 1)
// both hosts are now in team2
checkPassingCount(0, 0, 1, 1)
// team policies are removed if the host is re-enrolled without a team
require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{teamPolicy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
checkPassingCount(1, 2)
require.NoError(t, ds.RecordPolicyQueryExecutions(ctx, host2, map[uint]*bool{team1Policy.ID: ptr.Bool(true), globalPolicy.ID: ptr.Bool(true)}, time.Now(), false))
checkPassingCount(1, 0, 2, 2)
// all host policies are removed when a host is re-enrolled
_, err = ds.EnrollHost(context.Background(), "2", "2", nil, 0)
_, err = ds.EnrollHost(ctx, "2", "2", nil, 0)
require.NoError(t, err)
checkPassingCount(0, 1)
checkPassingCount(0, 0, 1, 1)
}
func testApplyPolicySpec(t *testing.T, ds *Datastore) {
@ -1065,7 +1175,7 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) {
assert.Equal(t, "some resolution", *policies[0].Resolution)
assert.Equal(t, "", policies[0].Platform)
teamPolicies, err := ds.ListTeamPolicies(ctx, team1.ID)
teamPolicies, _, err := ds.ListTeamPolicies(ctx, team1.ID)
require.NoError(t, err)
require.Len(t, teamPolicies, 2)
assert.Equal(t, "query2", teamPolicies[0].Name)
@ -1117,7 +1227,7 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) {
policies, err = ds.ListGlobalPolicies(ctx)
require.NoError(t, err)
require.Len(t, policies, 1)
teamPolicies, err = ds.ListTeamPolicies(ctx, team1.ID)
teamPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID)
require.NoError(t, err)
require.Len(t, teamPolicies, 2)
@ -1153,7 +1263,7 @@ func testApplyPolicySpec(t *testing.T, ds *Datastore) {
assert.Equal(t, "some resolution updated", *policies[0].Resolution)
assert.Equal(t, "", policies[0].Platform)
teamPolicies, err = ds.ListTeamPolicies(ctx, team1.ID)
teamPolicies, _, err = ds.ListTeamPolicies(ctx, team1.ID)
require.NoError(t, err)
require.Len(t, teamPolicies, 2)
@ -1529,7 +1639,7 @@ func testPolicyPlatformUpdate(t *testing.T, ds *Datastore) {
require.NoError(t, err)
require.Len(t, gpols, 2)
// load the team policies
tpols, err := ds.ListTeamPolicies(ctx, tm.ID)
tpols, _, err := ds.ListTeamPolicies(ctx, tm.ID)
require.NoError(t, err)
require.Len(t, tpols, 2)

View file

@ -478,7 +478,7 @@ type Datastore interface {
// Team Policies
NewTeamPolicy(ctx context.Context, teamID uint, authorID *uint, args PolicyPayload) (*Policy, error)
ListTeamPolicies(ctx context.Context, teamID uint) ([]*Policy, error)
ListTeamPolicies(ctx context.Context, teamID uint) (teamPolicies, inheritedPolicies []*Policy, err error)
DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error)
TeamPolicy(ctx context.Context, teamID uint, policyID uint) (*Policy, error)

View file

@ -507,7 +507,7 @@ type Service interface {
// Team Policies
NewTeamPolicy(ctx context.Context, teamID uint, p PolicyPayload) (*Policy, error)
ListTeamPolicies(ctx context.Context, teamID uint) ([]*Policy, error)
ListTeamPolicies(ctx context.Context, teamID uint) (teamPolicies, inheritedPolicies []*Policy, err error)
DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error)
ModifyTeamPolicy(ctx context.Context, teamID uint, id uint, p ModifyPolicyPayload) (*Policy, error)
GetTeamPolicyByIDQueries(ctx context.Context, teamID uint, policyID uint) (*Policy, error)

View file

@ -373,7 +373,7 @@ type DeleteSoftwareVulnerabilitiesFunc func(ctx context.Context, vulnerabilities
type NewTeamPolicyFunc func(ctx context.Context, teamID uint, authorID *uint, args fleet.PolicyPayload) (*fleet.Policy, error)
type ListTeamPoliciesFunc func(ctx context.Context, teamID uint) ([]*fleet.Policy, error)
type ListTeamPoliciesFunc func(ctx context.Context, teamID uint) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error)
type DeleteTeamPoliciesFunc func(ctx context.Context, teamID uint, ids []uint) ([]uint, error)
@ -2075,7 +2075,7 @@ func (s *DataStore) NewTeamPolicy(ctx context.Context, teamID uint, authorID *ui
return s.NewTeamPolicyFunc(ctx, teamID, authorID, args)
}
func (s *DataStore) ListTeamPolicies(ctx context.Context, teamID uint) ([]*fleet.Policy, error) {
func (s *DataStore) ListTeamPolicies(ctx context.Context, teamID uint) (teamPolicies []*fleet.Policy, inheritedPolicies []*fleet.Policy, err error) {
s.ListTeamPoliciesFuncInvoked = true
return s.ListTeamPoliciesFunc(ctx, teamID)
}

View file

@ -1801,6 +1801,7 @@ func (s *integrationTestSuite) TestTeamPoliciesProprietary() {
require.NotNil(t, policiesResponse.Policies[0].Resolution)
assert.Equal(t, "some team resolution updated", *policiesResponse.Policies[0].Resolution)
assert.Equal(t, "darwin", policiesResponse.Policies[0].Platform)
require.Len(t, policiesResponse.InheritedPolicies, 0)
listHostsURL := fmt.Sprintf("/api/latest/fleet/hosts?policy_id=%d", policiesResponse.Policies[0].ID)
listHostsResp := listHostsResponse{}

View file

@ -284,6 +284,15 @@ func (s *integrationEnterpriseTestSuite) TestTeamPolicies() {
ts := listTeamPoliciesResponse{}
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/teams/%d/policies", team1.ID), nil, http.StatusOK, &ts)
require.Len(t, ts.Policies, 0)
require.Len(t, ts.InheritedPolicies, 0)
// create a global policy
gpol, err := s.ds.NewGlobalPolicy(context.Background(), nil, fleet.PolicyPayload{Name: "TestGlobalPolicy", Query: "SELECT 1"})
require.NoError(t, err)
defer func() {
_, err := s.ds.DeleteGlobalPolicies(context.Background(), []uint{gpol.ID})
require.NoError(t, err)
}()
qr, err := s.ds.NewQuery(context.Background(), &fleet.Query{Name: "TestQuery2", Description: "Some description", Query: "select * from osquery;", ObserverCanRun: true})
require.NoError(t, err)
@ -303,6 +312,9 @@ func (s *integrationEnterpriseTestSuite) TestTeamPolicies() {
assert.Equal(t, "Some description", ts.Policies[0].Description)
require.NotNil(t, ts.Policies[0].Resolution)
assert.Equal(t, "some team resolution", *ts.Policies[0].Resolution)
require.Len(t, ts.InheritedPolicies, 1)
assert.Equal(t, gpol.Name, ts.InheritedPolicies[0].Name)
assert.Equal(t, gpol.ID, ts.InheritedPolicies[0].ID)
deletePolicyParams := deleteTeamPoliciesRequest{IDs: []uint{ts.Policies[0].ID}}
deletePolicyResp := deleteTeamPoliciesResponse{}

View file

@ -96,32 +96,33 @@ type listTeamPoliciesRequest struct {
}
type listTeamPoliciesResponse struct {
Policies []*fleet.Policy `json:"policies,omitempty"`
Err error `json:"error,omitempty"`
Policies []*fleet.Policy `json:"policies,omitempty"`
InheritedPolicies []*fleet.Policy `json:"inherited_policies,omitempty"`
Err error `json:"error,omitempty"`
}
func (r listTeamPoliciesResponse) error() error { return r.Err }
func listTeamPoliciesEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (interface{}, error) {
req := request.(*listTeamPoliciesRequest)
resp, err := svc.ListTeamPolicies(ctx, req.TeamID)
tmPols, inheritedPols, err := svc.ListTeamPolicies(ctx, req.TeamID)
if err != nil {
return listTeamPoliciesResponse{Err: err}, nil
}
return listTeamPoliciesResponse{Policies: resp}, nil
return listTeamPoliciesResponse{Policies: tmPols, InheritedPolicies: inheritedPols}, nil
}
func (svc Service) ListTeamPolicies(ctx context.Context, teamID uint) ([]*fleet.Policy, error) {
func (svc *Service) ListTeamPolicies(ctx context.Context, teamID uint) (teamPolicies, inheritedPolicies []*fleet.Policy, err error) {
if err := svc.authz.Authorize(ctx, &fleet.Policy{
PolicyData: fleet.PolicyData{
TeamID: ptr.Uint(teamID),
},
}, fleet.ActionRead); err != nil {
return nil, err
return nil, nil, err
}
if _, err := svc.ds.Team(ctx, teamID); err != nil {
return nil, ctxerr.Wrapf(ctx, err, "loading team %d", teamID)
return nil, nil, ctxerr.Wrapf(ctx, err, "loading team %d", teamID)
}
return svc.ds.ListTeamPolicies(ctx, teamID)

View file

@ -24,8 +24,8 @@ func TestTeamPoliciesAuth(t *testing.T) {
},
}, nil
}
ds.ListTeamPoliciesFunc = func(ctx context.Context, teamID uint) ([]*fleet.Policy, error) {
return nil, nil
ds.ListTeamPoliciesFunc = func(ctx context.Context, teamID uint) (tpol, ipol []*fleet.Policy, err error) {
return nil, nil, nil
}
ds.PoliciesByIDFunc = func(ctx context.Context, ids []uint) (map[uint]*fleet.Policy, error) {
return nil, nil
@ -149,7 +149,7 @@ func TestTeamPoliciesAuth(t *testing.T) {
})
checkAuthErr(t, tt.shouldFailWrite, err)
_, err = svc.ListTeamPolicies(ctx, 1)
_, _, err = svc.ListTeamPolicies(ctx, 1)
checkAuthErr(t, tt.shouldFailRead, err)
_, err = svc.GetTeamPolicyByIDQueries(ctx, 1, 1)