From 74ccff8161f3cfdc755449830bb1c050a37797ec Mon Sep 17 00:00:00 2001 From: Tim Lee Date: Wed, 23 Aug 2023 10:34:55 -0600 Subject: [PATCH] 13433 host query optimization (#13451) --- changes/13433-host-query-optimization | 1 + ee/server/service/teams.go | 8 +++- server/datastore/mysql/hosts.go | 28 ++++------- server/datastore/mysql/policies.go | 47 ++++++++++++++++++ server/datastore/mysql/policies_test.go | 63 +++++++++++++++++++++++++ server/service/hosts.go | 2 + 6 files changed, 129 insertions(+), 20 deletions(-) create mode 100644 changes/13433-host-query-optimization diff --git a/changes/13433-host-query-optimization b/changes/13433-host-query-optimization new file mode 100644 index 0000000000..aa1fc4fab9 --- /dev/null +++ b/changes/13433-host-query-optimization @@ -0,0 +1 @@ +- optimized hosts queries when using policy statuses \ No newline at end of file diff --git a/ee/server/service/teams.go b/ee/server/service/teams.go index ac7e793896..b8014bb8ec 100644 --- a/ee/server/service/teams.go +++ b/ee/server/service/teams.go @@ -449,7 +449,13 @@ func (svc *Service) DeleteTeam(ctx context.Context, teamID uint) error { } filter := fleet.TeamFilter{User: vc.User, IncludeObserver: true} - hosts, err := svc.ds.ListHosts(ctx, filter, fleet.HostListOptions{TeamFilter: &teamID}) + + opts := fleet.HostListOptions{ + TeamFilter: &teamID, + DisableFailingPolicies: true, // don't need to check policies for hosts that are being deleted + } + + hosts, err := svc.ds.ListHosts(ctx, filter, opts) if err != nil { return ctxerr.Wrap(ctx, err, "list hosts for reconcile profiles on team change") } diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 5f29b057c9..0f3f0269e1 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -833,15 +833,6 @@ func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt ` } - failingPoliciesSelect := `, - coalesce(failing_policies.count, 0) as failing_policies_count, - coalesce(failing_policies.count, 0) as total_issues_count - ` - if opt.DisableFailingPolicies { - failingPoliciesSelect = "" - } - sql += failingPoliciesSelect - var params []interface{} // Only include "additional" if filter provided. @@ -863,6 +854,7 @@ func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt ) FROM host_additional WHERE host_id = h.id) AS additional ` } + sql, params = ds.applyHostFilters(opt, sql, filter, params) hosts := []*fleet.Host{} @@ -870,6 +862,14 @@ func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt return nil, ctxerr.Wrap(ctx, err, "list hosts") } + if !opt.DisableFailingPolicies { + var err error + hosts, err = ds.UpdatePolicyFailureCountsForHosts(ctx, hosts) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "update policy failure counts for hosts") + } + } + return hosts, nil } @@ -902,14 +902,6 @@ func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, fil params = append(params, opt.SoftwareIDFilter) } - failingPoliciesJoin := `LEFT JOIN ( - SELECT host_id, count(*) as count FROM policy_membership WHERE passes = 0 - GROUP BY host_id - ) as failing_policies ON (h.id=failing_policies.host_id)` - if opt.DisableFailingPolicies { - failingPoliciesJoin = "" - } - operatingSystemJoin := "" if opt.OSIDFilter != nil || (opt.OSNameFilter != nil && opt.OSVersionFilter != nil) { operatingSystemJoin = `JOIN host_operating_system hos ON h.id = hos.host_id` @@ -944,7 +936,6 @@ func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, fil %s %s %s - %s %s WHERE TRUE AND %s AND %s AND %s AND %s `, @@ -953,7 +944,6 @@ func (ds *Datastore) applyHostFilters(opt fleet.HostListOptions, sql string, fil hostMDMJoin, deviceMappingJoin, policyMembershipJoin, - failingPoliciesJoin, operatingSystemJoin, munkiJoin, displayNameJoin, diff --git a/server/datastore/mysql/policies.go b/server/datastore/mysql/policies.go index 83edb7c8d3..b2c90b91ba 100644 --- a/server/datastore/mysql/policies.go +++ b/server/datastore/mysql/policies.go @@ -747,6 +747,53 @@ func (ds *Datastore) CleanupPolicyMembership(ctx context.Context, now time.Time) return nil } +func (ds *Datastore) UpdatePolicyFailureCountsForHosts(ctx context.Context, hosts []*fleet.Host) ([]*fleet.Host, error) { + // Get policy failure counts for each host + hostIDs := make([]uint, 0, len(hosts)) + + for _, host := range hosts { + hostIDs = append(hostIDs, host.ID) + } + + query, args, err := sqlx.In(` + SELECT + pm.host_id, + COUNT(*) AS failing_policy_count + FROM + policy_membership pm + WHERE + pm.passes = 0 AND + pm.host_id IN (?) + GROUP BY + pm.host_id + `, hostIDs) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "build policy failure count query") + } + + var policyFailureCounts []struct { + HostID uint `db:"host_id"` + FailingPolicyCount int `db:"failing_policy_count"` + } + + if err := sqlx.SelectContext(ctx, ds.reader(ctx), &policyFailureCounts, query, args...); err != nil { + return nil, ctxerr.Wrap(ctx, err, "get policy failure counts for hosts") + } + + // Map policy failure counts to hosts + hostIDToPolicyFailureCounts := make(map[uint]int) + for _, policyFailureCount := range policyFailureCounts { + hostIDToPolicyFailureCounts[policyFailureCount.HostID] = policyFailureCount.FailingPolicyCount + } + + for _, host := range hosts { + host.TotalIssuesCount = hostIDToPolicyFailureCounts[host.ID] + host.FailingPoliciesCount = hostIDToPolicyFailureCounts[host.ID] + } + + return hosts, nil +} + // PolicyViolationDays is a structure used for aggregate counts of policy violation days. type PolicyViolationDays struct { // FailingHostCount is an aggregate count of actual policy violations days. One actual policy diff --git a/server/datastore/mysql/policies_test.go b/server/datastore/mysql/policies_test.go index f838047881..fc7a6c4510 100644 --- a/server/datastore/mysql/policies_test.go +++ b/server/datastore/mysql/policies_test.go @@ -48,6 +48,7 @@ func TestPolicies(t *testing.T) { {"PolicyViolationDays", testPolicyViolationDays}, {"IncreasePolicyAutomationIteration", testIncreasePolicyAutomationIteration}, {"OutdatedAutomationBatch", testOutdatedAutomationBatch}, + {"TestUpdatePolicyFailureCountsForHosts", testUpdatePolicyFailureCountsForHosts}, {"TestPolicyIDsByName", testPolicyByName}, } for _, c := range cases { @@ -2209,3 +2210,65 @@ func testOutdatedAutomationBatch(t *testing.T, ds *Datastore) { require.NoError(t, err) require.ElementsMatch(t, batch, []fleet.PolicyFailure{}) } + +func testUpdatePolicyFailureCountsForHosts(t *testing.T, ds *Datastore) { + ctx := context.Background() + + // create 4 hosts + var hosts []*fleet.Host + for i := 0; i < 4; i++ { + h, err := ds.NewHost(ctx, &fleet.Host{OsqueryHostID: ptr.String(fmt.Sprintf("host%d", i)), NodeKey: ptr.String(fmt.Sprintf("host%d", i))}) + require.NoError(t, err) + hosts = append(hosts, h) + } + + // create 2 policies + var pols []*fleet.Policy + for i := 0; i < 2; i++ { + p, err := ds.NewGlobalPolicy(ctx, nil, fleet.PolicyPayload{Name: fmt.Sprintf("policy%d", i)}) + require.NoError(t, err) + pols = append(pols, p) + } + + // create policy membership for hosts + _, err := ds.writer(ctx).ExecContext(ctx, ` + INSERT INTO policy_membership (policy_id, host_id, passes) + VALUES + (?, ?, 1), + (?, ?, 1), + (?, ?, 0), + (?, ?, 0), + (?, ?, 1), + (?, ?, 0) + `, + pols[0].ID, hosts[0].ID, + pols[0].ID, hosts[1].ID, + pols[0].ID, hosts[2].ID, + pols[1].ID, hosts[0].ID, + pols[1].ID, hosts[1].ID, + pols[1].ID, hosts[2].ID, + ) + + require.NoError(t, err) + + // update policy failure counts for hosts + hostsUpdated, err := ds.UpdatePolicyFailureCountsForHosts(ctx, hosts) + require.NoError(t, err) + require.Len(t, hostsUpdated, 4) + + // host 0 should have 1 failing policy + assert.Equal(t, 1, hostsUpdated[0].TotalIssuesCount) + assert.Equal(t, 1, hostsUpdated[0].FailingPoliciesCount) + + // host 1 should have 0 failing policies + assert.Equal(t, 0, hostsUpdated[1].TotalIssuesCount) + assert.Equal(t, 0, hostsUpdated[1].FailingPoliciesCount) + + // host 2 should have 2 failing policies + assert.Equal(t, 2, hostsUpdated[2].TotalIssuesCount) + assert.Equal(t, 2, hostsUpdated[2].FailingPoliciesCount) + + // host 3 doesn't have any policy membership + assert.Equal(t, 0, hostsUpdated[3].TotalIssuesCount) + assert.Equal(t, 0, hostsUpdated[3].FailingPoliciesCount) +} diff --git a/server/service/hosts.go b/server/service/hosts.go index 49072b9c78..de0e7d3c68 100644 --- a/server/service/hosts.go +++ b/server/service/hosts.go @@ -210,6 +210,7 @@ func (svc *Service) DeleteHosts(ctx context.Context, ids []uint, opts fleet.Host return svc.ds.DeleteHosts(ctx, ids) } + opts.DisableFailingPolicies = true // don't check policies for hosts that are about to be deleted hostIDs, _, err := svc.hostIDsAndNamesFromFilters(ctx, opts, lid) if err != nil { return err @@ -961,6 +962,7 @@ func (svc *Service) hostIDsAndNamesFromFilters(ctx context.Context, opt fleet.Ho if lid != nil { hosts, err = svc.ds.ListHostsInLabel(ctx, filter, *lid, opt) } else { + opt.DisableFailingPolicies = true // intentionally ignore failing policies hosts, err = svc.ds.ListHosts(ctx, filter, opt) } if err != nil {