From 4f726a724c8a82765abe4b461a489593dfdf3fc5 Mon Sep 17 00:00:00 2001 From: Ian Littman Date: Thu, 14 Nov 2024 11:09:51 -0600 Subject: [PATCH] Allow Fleet Premium users to opt out of populating vulnerability details when populating software in the hosts list endpoint (#23710) #23078 This endpoint is drastically more efficient, and returns a much smaller response payload, when vulnerability details aren't returned, and vulnerability details can be looked up more efficiently in the /vulnerabilities/CVE-XXXX-YYYY endpoint as that endpoint returns the description once overall rather than once per host. # Checklist for submitter If some of the following don't apply, delete the relevant line. - [x] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/Committing-Changes.md#changes-files) for more information. - [x] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [x] Added/updated tests - [x] Manual QA for all new/changed functionality --- changes/23078-allow-skipping-vuln-details | 1 + server/fleet/hosts.go | 4 +++ server/service/hosts.go | 4 ++- server/service/hosts_test.go | 31 ++++++++++++++++++- server/service/integration_enterprise_test.go | 23 ++++++++++++++ server/service/transport.go | 9 ++++-- 6 files changed, 68 insertions(+), 4 deletions(-) create mode 100644 changes/23078-allow-skipping-vuln-details diff --git a/changes/23078-allow-skipping-vuln-details b/changes/23078-allow-skipping-vuln-details new file mode 100644 index 0000000000..7a29933976 --- /dev/null +++ b/changes/23078-allow-skipping-vuln-details @@ -0,0 +1 @@ +* Allowed skipping computationally heavy population of vulnerability details when populating host software on hosts list endpoint (`GET /api/latest/fleet/hosts`) when using Fleet Premium (`populate_software=without_vulnerability_descriptions`) \ No newline at end of file diff --git a/server/fleet/hosts.go b/server/fleet/hosts.go index 0ff7d1bf62..3ff287c9f1 100644 --- a/server/fleet/hosts.go +++ b/server/fleet/hosts.go @@ -200,6 +200,10 @@ type HostListOptions struct { // PopulateSoftware adds the `Software` field to all Hosts returned. PopulateSoftware bool + // PopulateSoftwareVulnerabilityDetails adds description, fix version, etc. fields to software vulnerabilities + // (this is a Premium feature that gets forced to false on Fleet Free) + PopulateSoftwareVulnerabilityDetails bool + // PopulatePolicies adds the `Policies` array field to all Hosts returned. PopulatePolicies bool diff --git a/server/service/hosts.go b/server/service/hosts.go index b8c66ba61e..33e32c4380 100644 --- a/server/service/hosts.go +++ b/server/service/hosts.go @@ -187,6 +187,8 @@ func (svc *Service) ListHosts(ctx context.Context, opt fleet.HostListOptions) ([ opt.LowDiskSpaceFilter = nil // the bootstrap package filter is premium-only opt.MDMBootstrapPackageFilter = nil + // including vulnerability details on software is premium-only + opt.PopulateSoftwareVulnerabilityDetails = false } hosts, err := svc.ds.ListHosts(ctx, filter, opt) @@ -210,7 +212,7 @@ func (svc *Service) ListHosts(ctx context.Context, opt fleet.HostListOptions) ([ if opt.PopulateSoftware { for _, host := range hosts { - if err = svc.ds.LoadHostSoftware(ctx, host, premiumLicense); err != nil { + if err = svc.ds.LoadHostSoftware(ctx, host, opt.PopulateSoftwareVulnerabilityDetails); err != nil { return nil, err } } diff --git a/server/service/hosts_test.go b/server/service/hosts_test.go index 9cdfacc6c4..5b0837cf60 100644 --- a/server/service/hosts_test.go +++ b/server/service/hosts_test.go @@ -803,7 +803,8 @@ func TestListHosts(t *testing.T) { }, nil } - hosts, err := svc.ListHosts(test.UserContext(ctx, test.UserAdmin), fleet.HostListOptions{}) + userContext := test.UserContext(ctx, test.UserAdmin) + hosts, err := svc.ListHosts(userContext, fleet.HostListOptions{}) require.NoError(t, err) require.Len(t, hosts, 1) @@ -811,6 +812,34 @@ func TestListHosts(t *testing.T) { _, err = svc.ListHosts(ctx, fleet.HostListOptions{}) require.Error(t, err) require.Contains(t, err.Error(), authz.ForbiddenErrorMessage) + + var shouldIncludeCVEScores bool + ds.LoadHostSoftwareFunc = func(ctx context.Context, host *fleet.Host, includeCVEScores bool) error { + require.Equal(t, shouldIncludeCVEScores, includeCVEScores) + return nil + } + + // free license disallows getting vuln details + hosts, err = svc.ListHosts(userContext, fleet.HostListOptions{PopulateSoftware: true, PopulateSoftwareVulnerabilityDetails: true}) + require.NoError(t, err) + require.Len(t, hosts, 1) + require.True(t, ds.LoadHostSoftwareFuncInvoked) + ds.LoadHostSoftwareFuncInvoked = false + + // you're allowed to skip vuln details on Premium + userContext = license.NewContext(userContext, &fleet.LicenseInfo{Tier: fleet.TierPremium}) + hosts, err = svc.ListHosts(userContext, fleet.HostListOptions{PopulateSoftware: true, PopulateSoftwareVulnerabilityDetails: false}) + require.NoError(t, err) + require.Len(t, hosts, 1) + require.True(t, ds.LoadHostSoftwareFuncInvoked) + ds.LoadHostSoftwareFuncInvoked = false + + // you're allowed to retrieve vuln details on Premium + shouldIncludeCVEScores = true + hosts, err = svc.ListHosts(userContext, fleet.HostListOptions{PopulateSoftware: true, PopulateSoftwareVulnerabilityDetails: true}) + require.NoError(t, err) + require.Len(t, hosts, 1) + require.True(t, ds.LoadHostSoftwareFuncInvoked) } func TestGetHostSummary(t *testing.T) { diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index 533e45a402..649d0b5912 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -3990,6 +3990,29 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() { } } + resp = listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "populate_software", "without_vulnerability_details") + require.Len(t, resp.Hosts, 3) + for _, h := range resp.Hosts { + if h.ID == host1.ID { + require.NotEmpty(t, h.Software) + require.Len(t, h.Software, 1) + require.NotEmpty(t, h.Software[0].Vulnerabilities) + + require.Nil(t, h.Software[0].Vulnerabilities[0].CVSSScore) + require.Nil(t, h.Software[0].Vulnerabilities[0].EPSSProbability) + require.Nil(t, h.Software[0].Vulnerabilities[0].CISAKnownExploit) + require.Nil(t, h.Software[0].Vulnerabilities[0].Description) + assert.Equal(t, uint64(1), h.HostIssues.FailingPoliciesCount) + assert.Equal(t, uint64(1), *h.HostIssues.CriticalVulnerabilitiesCount) + assert.Equal(t, uint64(2), h.HostIssues.TotalIssuesCount) + } else { + assert.Zero(t, h.HostIssues.FailingPoliciesCount) + assert.Zero(t, *h.HostIssues.CriticalVulnerabilitiesCount) + assert.Zero(t, h.HostIssues.TotalIssuesCount) + } + } + resp = listHostsResponse{} s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "populate_software", "false") require.Len(t, resp.Hosts, 3) diff --git a/server/service/transport.go b/server/service/transport.go index d2518e146c..b92b7e8724 100644 --- a/server/service/transport.go +++ b/server/service/transport.go @@ -564,13 +564,18 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) hopt.LowDiskSpaceFilter = &v } populateSoftware := r.URL.Query().Get("populate_software") - if populateSoftware != "" { + if populateSoftware == "without_vulnerability_details" { + hopt.PopulateSoftware = true + hopt.PopulateSoftwareVulnerabilityDetails = false + } else if populateSoftware != "" { ps, err := strconv.ParseBool(populateSoftware) if err != nil { - return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid populate_software: %s", populateSoftware))) + return hopt, ctxerr.Wrap(r.Context(), badRequest(`Invalid value for populate_software. Should be one of "true", "false", or "without_vulnerability_details".`)) } hopt.PopulateSoftware = ps + hopt.PopulateSoftwareVulnerabilityDetails = ps } + populatePolicies := r.URL.Query().Get("populate_policies") if populatePolicies != "" { pp, err := strconv.ParseBool(populatePolicies)