diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index feb6272eeb..df1ca09310 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -451,6 +451,25 @@ func (ds *Datastore) CountPolicies(ctx context.Context, teamID *uint, matchQuery return count, nil } +func (ds *Datastore) CountMergedTeamPolicies(ctx context.Context, teamID uint, matchQuery string) (int, error) { + var args []interface{} + + query := `SELECT count(*) FROM policies p WHERE p.team_id = ? OR p.team_id IS NULL` + args = append(args, teamID) + + // We must normalize the name for full Unicode support (Unicode equivalence). + match := norm.NFC.String(matchQuery) + query, args = searchLike(query, args, match, policySearchColumns...) + + var count int + err := sqlx.GetContext(ctx, ds.reader(ctx), &count, query, args...) + if err != nil { + return 0, ctxerr.Wrap(ctx, err, "counting merged team policies") + } + + return count, nil +} + func (ds *Datastore) PoliciesByID(ctx context.Context, ids []uint) (map[uint]*fleet.Policy, error) { sql := `SELECT ` + policyCols + `, COALESCE(u.name, '') AS author_name, diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index 57514af79b..a784f3a37b 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -2954,6 +2954,10 @@ func testCountPolicies(t *testing.T, ds *Datastore) { require.NoError(t, err) assert.Equal(t, 0, teamCount) + mergedCount, err := ds.CountMergedTeamPolicies(ctx, tm.ID, "") + require.NoError(t, err) + assert.Equal(t, 0, mergedCount) + // 10 global policies for i := 0; i < 10; i++ { _, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: fmt.Sprintf("global policy %d", i)}) @@ -2968,6 +2972,10 @@ func testCountPolicies(t *testing.T, ds *Datastore) { require.NoError(t, err) assert.Equal(t, 0, teamCount) + mergedCount, err = ds.CountMergedTeamPolicies(ctx, tm.ID, "") + require.NoError(t, err) + assert.Equal(t, 10, mergedCount) + // add 5 team policies for i := 0; i < 5; i++ { _, err := ds.NewTeamPolicy(ctx, tm.ID, nil, fleet.PolicyPayload{Name: fmt.Sprintf("team policy %d", i)}) @@ -2981,6 +2989,10 @@ func testCountPolicies(t *testing.T, ds *Datastore) { globalCount, err = ds.CountPolicies(ctx, nil, "") require.NoError(t, err) assert.Equal(t, 10, globalCount) + + mergedCount, err = ds.CountMergedTeamPolicies(ctx, tm.ID, "") + require.NoError(t, err) + assert.Equal(t, 15, mergedCount) } func testUpdatePolicyHostCounts(t *testing.T, ds *Datastore) { diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 05c1e6cd96..5d1bf64534 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -612,6 +612,7 @@ type Datastore interface { PoliciesByID(ctx context.Context, ids []uint) (map[uint]*Policy, error) DeleteGlobalPolicies(ctx context.Context, ids []uint) ([]uint, error) CountPolicies(ctx context.Context, teamID *uint, matchQuery string) (int, error) + CountMergedTeamPolicies(ctx context.Context, teamID uint, matchQuery string) (int, error) UpdateHostPolicyCounts(ctx context.Context) error PolicyQueriesForHost(ctx context.Context, host *Host) (map[string]string, error) diff --git a/server/fleet/service.go b/server/fleet/service.go index d47b60fcb3..f5616443ea 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -646,7 +646,7 @@ type Service interface { DeleteTeamPolicies(ctx context.Context, teamID uint, ids []uint) ([]uint, error) ModifyTeamPolicy(ctx context.Context, teamID uint, id uint, p ModifyPolicyPayload) (*Policy, error) GetTeamPolicyByIDQueries(ctx context.Context, teamID uint, policyID uint) (*Policy, error) - CountTeamPolicies(ctx context.Context, teamID uint, matchQuery string) (int, error) + CountTeamPolicies(ctx context.Context, teamID uint, matchQuery string, mergeInherited bool) (int, error) // ///////////////////////////////////////////////////////////////////////////// // Geolocation diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 50ab787cac..834bb4499c 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -447,6 +447,8 @@ type DeleteGlobalPoliciesFunc func(ctx context.Context, ids []uint) ([]uint, err type CountPoliciesFunc func(ctx context.Context, teamID *uint, matchQuery string) (int, error) +type CountMergedTeamPoliciesFunc func(ctx context.Context, teamID uint, matchQuery string) (int, error) + type UpdateHostPolicyCountsFunc func(ctx context.Context) error type PolicyQueriesForHostFunc func(ctx context.Context, host *fleet.Host) (map[string]string, error) @@ -1566,6 +1568,9 @@ type DataStore struct { CountPoliciesFunc CountPoliciesFunc CountPoliciesFuncInvoked bool + CountMergedTeamPoliciesFunc CountMergedTeamPoliciesFunc + CountMergedTeamPoliciesFuncInvoked bool + UpdateHostPolicyCountsFunc UpdateHostPolicyCountsFunc UpdateHostPolicyCountsFuncInvoked bool @@ -3781,6 +3786,13 @@ func (s *DataStore) CountPolicies(ctx context.Context, teamID *uint, matchQuery return s.CountPoliciesFunc(ctx, teamID, matchQuery) } +func (s *DataStore) CountMergedTeamPolicies(ctx context.Context, teamID uint, matchQuery string) (int, error) { + s.mu.Lock() + s.CountMergedTeamPoliciesFuncInvoked = true + s.mu.Unlock() + return s.CountMergedTeamPoliciesFunc(ctx, teamID, matchQuery) +} + func (s *DataStore) UpdateHostPolicyCounts(ctx context.Context) error { s.mu.Lock() s.UpdateHostPolicyCountsFuncInvoked = true diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index 7c043d7dcd..d7905364f5 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -844,6 +844,16 @@ func (s *integrationEnterpriseTestSuite) TestTeamPolicies() { assert.Equal(t, gpol.Name, ts.InheritedPolicies[0].Name) assert.Equal(t, gpol.ID, ts.InheritedPolicies[0].ID) + tc := countTeamPoliciesResponse{} + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/teams/%d/policies/count", team1.ID), nil, http.StatusOK, &tc) + require.Nil(t, tc.Err) + require.Equal(t, 1, tc.Count) + + gc := countGlobalPoliciesResponse{} + s.DoJSON("GET", "/api/latest/fleet/policies/count", nil, http.StatusOK, &gc) + require.Nil(t, gc.Err) + require.Equal(t, 1, gc.Count) + // Test merge inherited ts = listTeamPoliciesResponse{} s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/teams/%d/policies", team1.ID), nil, http.StatusOK, &ts, "merge_inherited", "true", "order_key", "team_id", "order_direction", "desc") @@ -857,6 +867,11 @@ func (s *integrationEnterpriseTestSuite) TestTeamPolicies() { assert.Equal(t, gpol.Name, ts.Policies[1].Name) assert.Equal(t, gpol.ID, ts.Policies[1].ID) + countResp := countTeamPoliciesResponse{} + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/teams/%d/policies/count", team1.ID), nil, http.StatusOK, &countResp, "merge_inherited", "true") + require.Nil(t, countResp.Err) + require.Equal(t, 2, countResp.Count) + // Test delete deletePolicyParams := deleteTeamPoliciesRequest{IDs: []uint{ts.Policies[0].ID}} deletePolicyResp := deleteTeamPoliciesResponse{} diff --git a/server/service/team_policies.go b/server/service/team_policies.go index ee55fa2a7f..a2698c145e 100644 --- a/server/service/team_policies.go +++ b/server/service/team_policies.go @@ -160,8 +160,9 @@ func (svc *Service) ListTeamPolicies(ctx context.Context, teamID uint, opts flee ///////////////////////////////////////////////////////////////////////////////// type countTeamPoliciesRequest struct { - ListOptions fleet.ListOptions `url:"list_options"` - TeamID uint `url:"team_id"` + ListOptions fleet.ListOptions `url:"list_options"` + TeamID uint `url:"team_id"` + MergeInherited bool `query:"merge_inherited,optional"` } type countTeamPoliciesResponse struct { @@ -173,14 +174,14 @@ func (r countTeamPoliciesResponse) error() error { return r.Err } func countTeamPoliciesEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { req := request.(*countTeamPoliciesRequest) - resp, err := svc.CountTeamPolicies(ctx, req.TeamID, req.ListOptions.MatchQuery) + resp, err := svc.CountTeamPolicies(ctx, req.TeamID, req.ListOptions.MatchQuery, req.MergeInherited) if err != nil { return countTeamPoliciesResponse{Err: err}, nil } return countTeamPoliciesResponse{Count: resp}, nil } -func (svc *Service) CountTeamPolicies(ctx context.Context, teamID uint, matchQuery string) (int, error) { +func (svc *Service) CountTeamPolicies(ctx context.Context, teamID uint, matchQuery string, mergeInherited bool) (int, error) { if err := svc.authz.Authorize(ctx, &fleet.Policy{ PolicyData: fleet.PolicyData{ TeamID: ptr.Uint(teamID), @@ -193,6 +194,10 @@ func (svc *Service) CountTeamPolicies(ctx context.Context, teamID uint, matchQue return 0, ctxerr.Wrapf(ctx, err, "loading team %d", teamID) } + if mergeInherited { + return svc.ds.CountMergedTeamPolicies(ctx, teamID, matchQuery) + } + return svc.ds.CountPolicies(ctx, &teamID, matchQuery) }