From eb9a1df045efad98b64e417623fbe256da59c58d Mon Sep 17 00:00:00 2001 From: Tim Lee Date: Thu, 15 Feb 2024 13:27:18 -0700 Subject: [PATCH] Host Vulnerability Filter (#16889) --- server/datastore/mysql/hosts.go | 20 +++++++ server/datastore/mysql/hosts_test.go | 79 ++++++++++++++++++++++++++++ server/fleet/hosts.go | 3 ++ server/service/transport.go | 5 ++ server/service/transport_test.go | 3 +- 5 files changed, 109 insertions(+), 1 deletion(-) diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index af68b08f80..0be8ff8984 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -1098,6 +1098,7 @@ func (ds *Datastore) applyHostFilters( sqlStmt, params = filterHostsByMDMBootstrapPackageStatus(sqlStmt, opt, params) sqlStmt, params = filterHostsByOS(sqlStmt, opt, params) + sqlStmt, params = filterHostsByVulnerability(sqlStmt, opt, params) sqlStmt, params, _ = hostSearchLike(sqlStmt, params, opt.MatchQuery, append(hostSearchColumns, "display_name")...) sqlStmt, params = appendListOptionsWithCursorToSQL(sqlStmt, params, &opt.ListOptions) @@ -1500,6 +1501,25 @@ func filterHostsByMDMBootstrapPackageStatus(sql string, opt fleet.HostListOption return sql + newSQL, params } +func filterHostsByVulnerability(sqlstmt string, opt fleet.HostListOptions, params []interface{}) (string, []interface{}) { + if opt.VulnerabilityFilter != nil { + sqlstmt += ` AND h.id IN ( + SELECT hs.host_id FROM host_software hs + JOIN software_cve sc ON sc.software_id = hs.software_id + WHERE sc.cve = ? + + UNION + + SELECT hos.host_id FROM host_operating_system hos + JOIN operating_system_vulnerabilities osv ON osv.operating_system_id = hos.os_id + WHERE osv.cve = ?)` + + params = append(params, opt.VulnerabilityFilter, opt.VulnerabilityFilter) + } + + return sqlstmt, params +} + func (ds *Datastore) CountHosts(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) (int, error) { sql := `SELECT count(*) ` diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index 8fb9aeaefb..dd8be6b6a8 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -114,6 +114,7 @@ func TestHosts(t *testing.T) { {"HostsListBySoftwareChangedAt", testHostsListBySoftwareChangedAt}, {"HostsListByOperatingSystemID", testHostsListByOperatingSystemID}, {"HostsListByOSNameAndVersion", testHostsListByOSNameAndVersion}, + {"HostsListByVulnerability", testHostsListByVulnerability}, {"HostsListByDiskEncryptionStatus", testHostsListMacOSSettingsDiskEncryptionStatus}, {"HostsListFailingPolicies", printReadsInTest(testHostsListFailingPolicies)}, {"HostsExpiration", testHostsExpiration}, @@ -3143,6 +3144,84 @@ func testHostsListByOSNameAndVersion(t *testing.T, ds *Datastore) { } } +func testHostsListByVulnerability(t *testing.T, ds *Datastore) { + // seed hosts + var hosts []*fleet.Host + for i := 0; i < 9; i++ { + h, err := ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now().Add(-time.Duration(i) * time.Minute), + OsqueryHostID: ptr.String(strconv.Itoa(i)), + NodeKey: ptr.String(fmt.Sprintf("%d", i)), + UUID: fmt.Sprintf("%d", i), + Hostname: fmt.Sprintf("foo.local%d", i), + }) + require.NoError(t, err) + hosts = append(hosts, h) + } + + // seed software + software := []fleet.Software{ + {Name: "foo", Version: "0.0.2", Source: "chrome_extensions"}, + } + + // add software to 5 hosts + var swVulnHostIDs []uint + for i := 0; i < 5; i++ { + _, err := ds.UpdateHostSoftware(context.Background(), hosts[i].ID, software) + require.NoError(t, err) + swVulnHostIDs = append(swVulnHostIDs, hosts[i].ID) + } + + // seed software vulnerabilities + vuln := fleet.SoftwareVulnerability{ + CVE: "CVE-2021-1234", + SoftwareID: 1, + } + + _, err := ds.InsertSoftwareVulnerability(context.Background(), vuln, fleet.NVDSource) + require.NoError(t, err) + + list, err := ds.ListHosts(context.Background(), fleet.TeamFilter{User: test.UserAdmin}, fleet.HostListOptions{VulnerabilityFilter: ptr.String("CVE-2021-1234")}) + require.NoError(t, err) + require.Len(t, list, 5) + for _, h := range list { + require.Contains(t, swVulnHostIDs, h.ID) + } + + // update 2 host operating system + os := fleet.OperatingSystem{ + Name: "Ubuntu", + Version: "20.4.0 LTS", + Arch: "x86_64", + Platform: "ubuntu", + KernelVersion: "5.10.76-linuxkit", + } + err = ds.UpdateHostOperatingSystem(context.Background(), hosts[0].ID, os) + require.NoError(t, err) + err = ds.UpdateHostOperatingSystem(context.Background(), hosts[1].ID, os) + require.NoError(t, err) + + // seed os vulnerability + osVulns := []fleet.OSVulnerability{ + { + OSID: 1, + CVE: "CVE-2021-1235", + }, + } + _, err = ds.InsertOSVulnerabilities(context.Background(), osVulns, fleet.NVDSource) + require.NoError(t, err) + + list, err = ds.ListHosts(context.Background(), fleet.TeamFilter{User: test.UserAdmin}, fleet.HostListOptions{VulnerabilityFilter: ptr.String("CVE-2021-1235")}) + require.NoError(t, err) + require.Len(t, list, 2) + for _, h := range list { + require.Contains(t, []uint{hosts[0].ID, hosts[1].ID}, h.ID) + } +} + func testHostsListMacOSSettingsDiskEncryptionStatus(t *testing.T, ds *Datastore) { ctx := context.Background() diff --git a/server/fleet/hosts.go b/server/fleet/hosts.go index dd709ec7c8..8f3b24cbab 100644 --- a/server/fleet/hosts.go +++ b/server/fleet/hosts.go @@ -181,6 +181,9 @@ type HostListOptions struct { // PopulateSoftware adds the `Software` field to all Hosts returned. PopulateSoftware bool + + // VulnerabilityFilter filters the hosts by the presence of a vulnerability (CVE) + VulnerabilityFilter *string } // TODO(Sarah): Are we missing any filters here? Should all MDM filters be included? diff --git a/server/service/transport.go b/server/service/transport.go index 44195aae9f..a0ba0a1124 100644 --- a/server/service/transport.go +++ b/server/service/transport.go @@ -324,6 +324,11 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) hopt.OSVersionFilter = &osVersion } + cve := r.URL.Query().Get("vulnerability") + if cve != "" { + hopt.VulnerabilityFilter = &cve + } + if hopt.OSNameFilter != nil && hopt.OSVersionFilter == nil { return hopt, ctxerr.Wrap( r.Context(), badRequest( diff --git a/server/service/transport_test.go b/server/service/transport_test.go index beff512280..ce2e605f7e 100644 --- a/server/service/transport_test.go +++ b/server/service/transport_test.go @@ -158,7 +158,7 @@ func TestHostListOptionsFromRequest(t *testing.T) { "&os_name=osName&os_version=osVersion&os_version_id=5&disable_failing_policies=1&macos_settings=verified" + "&macos_settings_disk_encryption=enforcing&os_settings=pending&os_settings_disk_encryption=failed" + "&bootstrap_package=installed&mdm_id=6&mdm_name=mdmName&mdm_enrollment_status=automatic" + - "&munki_issue_id=7&low_disk_space=99", + "&munki_issue_id=7&low_disk_space=99&vulnerability=CVE-2023-42887", hostListOptions: fleet.HostListOptions{ ListOptions: fleet.ListOptions{ OrderKey: "foo", @@ -188,6 +188,7 @@ func TestHostListOptionsFromRequest(t *testing.T) { MDMEnrollmentStatusFilter: fleet.MDMEnrollStatusAutomatic, MunkiIssueIDFilter: ptr.Uint(7), LowDiskSpaceFilter: ptr.Int(99), + VulnerabilityFilter: ptr.String("CVE-2023-42887"), }, }, "policy_id and policy_response params (for coverage)": {