From f1eeaf42f2b0e52135a65e591e3a96344301d26a Mon Sep 17 00:00:00 2001 From: Tim Lee Date: Fri, 9 Feb 2024 20:54:44 -0700 Subject: [PATCH] 2 of 2: List Vulnerabilities API (#16695) --- ee/server/service/vulnerabilities.go | 22 +++ server/datastore/mysql/vulnerabilities.go | 60 +++++- .../datastore/mysql/vulnerabilities_test.go | 32 ++- server/fleet/service.go | 8 + server/fleet/vulnerabilities.go | 38 ++-- server/service/handler.go | 4 + server/service/integration_core_test.go | 150 ++++++++++++++ server/service/integration_enterprise_test.go | 185 +++++++++++++++++- server/service/vulnerabilities.go | 92 +++++++++ server/service/vulnerabilities_test.go | 53 +++++ 10 files changed, 617 insertions(+), 27 deletions(-) create mode 100644 ee/server/service/vulnerabilities.go create mode 100644 server/service/vulnerabilities.go create mode 100644 server/service/vulnerabilities_test.go diff --git a/ee/server/service/vulnerabilities.go b/ee/server/service/vulnerabilities.go new file mode 100644 index 0000000000..41f435b06f --- /dev/null +++ b/ee/server/service/vulnerabilities.go @@ -0,0 +1,22 @@ +package service + +import ( + "context" + + "github.com/fleetdm/fleet/v4/server/fleet" +) + +var eeValidVulnSortColumns = []string{ + "cve", + "host_count", + "created_at", + "cvss_score", + "epss_probability", + "published", +} + +func (svc *Service) ListVulnerabilities(ctx context.Context, opt fleet.VulnListOptions) ([]fleet.VulnerabilityWithMetadata, *fleet.PaginationMetadata, error) { + opt.ValidSortColumns = eeValidVulnSortColumns + opt.IsEE = true + return svc.Service.ListVulnerabilities(ctx, opt) +} diff --git a/server/datastore/mysql/vulnerabilities.go b/server/datastore/mysql/vulnerabilities.go index 9ba051710b..26abc36a83 100644 --- a/server/datastore/mysql/vulnerabilities.go +++ b/server/datastore/mysql/vulnerabilities.go @@ -11,7 +11,8 @@ import ( ) func (ds *Datastore) ListVulnerabilities(ctx context.Context, opt fleet.VulnListOptions) ([]fleet.VulnerabilityWithMetadata, *fleet.PaginationMetadata, error) { - selectStmt := ` + // Define base select statements for EE and Free versions + eeSelectStmt := ` SELECT vhc.cve, MIN(COALESCE(osv.created_at, sc.created_at, NOW())) AS created_at, @@ -21,7 +22,8 @@ func (ds *Datastore) ListVulnerabilities(ctx context.Context, opt fleet.VulnList cm.cisa_known_exploit, cm.published, COALESCE(cm.description, '') AS description, - vhc.host_count + vhc.host_count, + vhc.updated_at as host_count_updated_at FROM vulnerability_host_counts vhc LEFT JOIN cve_meta cm ON cm.cve = vhc.cve @@ -29,8 +31,32 @@ func (ds *Datastore) ListVulnerabilities(ctx context.Context, opt fleet.VulnList LEFT JOIN software_cve sc ON sc.cve = vhc.cve WHERE vhc.host_count > 0 ` - groupByAppend := ` GROUP BY + freeSelectStmt := ` + SELECT + vhc.cve, + MIN(COALESCE(osv.created_at, sc.created_at, NOW())) AS created_at, + COALESCE(osv.source, sc.source, 0) AS source, + vhc.host_count, + vhc.updated_at as host_count_updated_at + FROM + vulnerability_host_counts vhc + LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = vhc.cve + LEFT JOIN software_cve sc ON sc.cve = vhc.cve + WHERE vhc.host_count > 0 + ` + + // Choose the appropriate select statement based on EE or Free + var selectStmt string + if opt.IsEE { + selectStmt = eeSelectStmt + } else { + selectStmt = freeSelectStmt + } + + // Define group by statements for EE and Free + eeGroupBy := ` GROUP BY vhc.cve, + source, cm.cvss_score, cm.epss_probability, cm.cisa_known_exploit, @@ -38,15 +64,33 @@ func (ds *Datastore) ListVulnerabilities(ctx context.Context, opt fleet.VulnList description, vhc.host_count ` + freeGroupBy := " GROUP BY vhc.cve, source, vhc.host_count" + // Choose the appropriate group by statement based on EE or Free + var groupBy string + if opt.IsEE { + groupBy = eeGroupBy + } else { + groupBy = freeGroupBy + } + + // Prepare arguments for the query var args []interface{} if opt.TeamID == 0 { - selectStmt = selectStmt + " AND vhc.team_id = 0" + selectStmt += " AND vhc.team_id = 0" } else { - selectStmt = selectStmt + " AND vhc.team_id = ?" + selectStmt += " AND vhc.team_id = ?" args = append(args, opt.TeamID) } + if opt.KnownExploit { + selectStmt += " AND cm.cisa_known_exploit = 1" + } + + if match := opt.MatchQuery; match != "" { + selectStmt, args = searchLike(selectStmt, args, match, "vhc.cve") + } + if opt.KnownExploit { selectStmt = selectStmt + " AND cm.cisa_known_exploit = 1" } @@ -55,17 +99,19 @@ func (ds *Datastore) ListVulnerabilities(ctx context.Context, opt fleet.VulnList selectStmt, args = searchLike(selectStmt, args, match, "vhc.cve") } - selectStmt = selectStmt + groupByAppend + // Append group by statement + selectStmt += groupBy opt.ListOptions.IncludeMetadata = !(opt.ListOptions.UsesCursorPagination()) - selectStmt, args = appendListOptionsWithCursorToSQL(selectStmt, args, &opt.ListOptions) + // Execute the query var vulns []fleet.VulnerabilityWithMetadata if err := sqlx.SelectContext(ctx, ds.reader(ctx), &vulns, selectStmt, args...); err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "list vulnerabilities") } + // Prepare metadata var metaData *fleet.PaginationMetadata if opt.ListOptions.IncludeMetadata { metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0} diff --git a/server/datastore/mysql/vulnerabilities_test.go b/server/datastore/mysql/vulnerabilities_test.go index f19d9c1afd..34954273f7 100644 --- a/server/datastore/mysql/vulnerabilities_test.go +++ b/server/datastore/mysql/vulnerabilities_test.go @@ -118,7 +118,35 @@ func testListVulnerabilities(t *testing.T, ds *Datastore) { Source: fleet.NVDSource, }, } - list, _, err = ds.ListVulnerabilities(context.Background(), opts) + list, _, err = ds.ListVulnerabilities(context.Background(), fleet.VulnListOptions{IsEE: true}) + require.NoError(t, err) + require.Len(t, list, 3) + for _, vuln := range list { + expectedVuln, ok := expected[vuln.CVE] + require.True(t, ok) + require.Equal(t, expectedVuln.CVEMeta, vuln.CVEMeta) + require.Equal(t, expectedVuln.HostCount, vuln.HostCount) + } + + // Test Fleet Free + expected = map[string]fleet.VulnerabilityWithMetadata{ + "CVE-2020-1234": { + CVEMeta: fleet.CVEMeta{CVE: "CVE-2020-1234"}, + HostCount: 10, + Source: fleet.MSRCSource, + }, + "CVE-2020-1235": { + CVEMeta: fleet.CVEMeta{CVE: "CVE-2020-1235"}, + HostCount: 15, + Source: fleet.MSRCSource, + }, + "CVE-2020-1236": { + CVEMeta: fleet.CVEMeta{CVE: "CVE-2020-1236"}, + HostCount: 20, + Source: fleet.NVDSource, + }, + } + list, _, err = ds.ListVulnerabilities(context.Background(), fleet.VulnListOptions{}) require.NoError(t, err) require.Len(t, list, 3) for _, vuln := range list { @@ -185,6 +213,7 @@ func testListVulnerabilitiesSort(t *testing.T, ds *Datastore) { seedVulnerabilities(t, ds) opts := fleet.VulnListOptions{ + IsEE: true, ListOptions: fleet.ListOptions{ Page: 0, PerPage: 5, @@ -219,6 +248,7 @@ func testVulnerabilitiesFilters(t *testing.T, ds *Datastore) { // Test KnownExploit filter opts := fleet.VulnListOptions{ + IsEE: true, KnownExploit: true, } list, _, err := ds.ListVulnerabilities(context.Background(), opts) diff --git a/server/fleet/service.go b/server/fleet/service.go index e0e870d20c..1cf71ddea0 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -605,6 +605,14 @@ type Service interface { ListSoftwareTitles(ctx context.Context, opt SoftwareTitleListOptions) ([]SoftwareTitle, int, *PaginationMetadata, error) SoftwareTitleByID(ctx context.Context, id uint) (*SoftwareTitle, error) + // ///////////////////////////////////////////////////////////////////////////// + // Vulnerabilities + + // ListVulnerabilities returns a list of vulnerabilities based on the provided options. + ListVulnerabilities(ctx context.Context, opt VulnListOptions) ([]VulnerabilityWithMetadata, *PaginationMetadata, error) + // CountVulnerabilities returns the number of vulnerabilities based on the provided options. + CountVulnerabilities(ctx context.Context, opt VulnListOptions) (uint, error) + // ///////////////////////////////////////////////////////////////////////////// // Team Policies diff --git a/server/fleet/vulnerabilities.go b/server/fleet/vulnerabilities.go index 32e577ac79..554b4e2233 100644 --- a/server/fleet/vulnerabilities.go +++ b/server/fleet/vulnerabilities.go @@ -21,22 +21,22 @@ type CVE struct { } type CVEMeta struct { - CVE string `db:"cve"` + CVE string `db:"cve" json:"cve"` // CVSSScore is the Common Vulnerability Scoring System (CVSS) base score v3. The base score ranges from 0 - 10 and // takes into account several different metrics. // See https://nvd.nist.gov/vuln-metrics/cvss. - CVSSScore *float64 `db:"cvss_score"` + CVSSScore *float64 `db:"cvss_score" json:"cvss_score,omitempty"` // EPSSProbability is the Exploit Prediction Scoring System (EPSS) score. It is the probability // that a software vulnerability will be exploited in the next 30 days. // See https://www.first.org/epss/. - EPSSProbability *float64 `db:"epss_probability"` + EPSSProbability *float64 `db:"epss_probability" json:"epss_probability,omitempty"` // CISAKnownExploit is whether the the software vulnerability is a known exploit according to CISA. // See https://www.cisa.gov/known-exploited-vulnerabilities. - CISAKnownExploit *bool `db:"cisa_known_exploit"` + CISAKnownExploit *bool `db:"cisa_known_exploit" json:"cisa_known_exploit,omitempty"` // Published is when the cve was published according to NIST.score - Published *time.Time `db:"published"` + Published *time.Time `db:"published" json:"published,omitempty"` // CVE text description - Description string `db:"description"` + Description string `db:"description" json:"description,omitempty"` } // SoftwareCPE represents an entry in the `software_cpe` table. @@ -129,15 +129,29 @@ const ( type VulnerabilityWithMetadata struct { CVEMeta - HostCount uint `db:"host_count"` - HostCountUpdatedAt time.Time `db:"host_count_updated_at"` - CreatedAt time.Time `db:"created_at"` - Source VulnerabilitySource `db:"source"` + HostCount uint `db:"host_count" json:"host_count"` + HostCountUpdatedAt time.Time `db:"host_count_updated_at" json:"host_count_updated_at"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + DetailsLink string `json:"details_link"` + Source VulnerabilitySource `db:"source" json:"-"` } type VulnListOptions struct { ListOptions + IsEE bool ValidSortColumns []string - TeamID uint - KnownExploit bool + TeamID uint `query:"team_id,optional"` + KnownExploit bool `query:"exploit,optional"` +} + +func (opt VulnListOptions) HasValidSortColumn() bool { + if opt.OrderKey == "" || len(opt.ValidSortColumns) == 0 { + return true + } + for _, c := range opt.ValidSortColumns { + if c == opt.OrderKey { + return true + } + } + return false } diff --git a/server/service/handler.go b/server/service/handler.go index d6ea2100ed..5947a61b89 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -371,6 +371,10 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC ue.GET("/api/_version_/fleet/software/titles", listSoftwareTitlesEndpoint, listSoftwareTitlesRequest{}) ue.GET("/api/_version_/fleet/software/titles/{id:[0-9]+}", getSoftwareTitleEndpoint, getSoftwareTitleRequest{}) + // Vulnerabilities + ue.GET("/api/_version_/fleet/vulnerabilities", listVulnerabilitiesEndpoint, listVulnerabilitiesRequest{}) + + // Hosts ue.GET("/api/_version_/fleet/host_summary", getHostSummaryEndpoint, getHostSummaryRequest{}) ue.GET("/api/_version_/fleet/hosts", listHostsEndpoint, listHostsRequest{}) ue.POST("/api/_version_/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{}) diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 9433253656..40aaa2211e 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -7367,6 +7367,156 @@ func (s *integrationTestSuite) TestGetHostDiskEncryption() { require.Contains(t, errMsg, fleet.ErrMDMNotConfigured.Error()) } +func (s *integrationTestSuite) TestListVulnerabilities() { + t := s.T() + var resp listVulnerabilitiesResponse + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp) + + // Invalid Order Key + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusBadRequest, &resp, "order_key", "foo", "order_direction", "asc") + + // EE Order Key is an invalid order key + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusBadRequest, &resp, "order_key", "cvss_score", "order_direction", "asc") + + // Exploit is an EE only filter + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusPaymentRequired, &resp, "exploit", "true") + + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp) + require.Len(s.T(), resp.Vulnerabilities, 0) + + host, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"), + OsqueryHostID: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"), + UUID: t.Name() + "2", + Hostname: t.Name() + "foo2.local", + PrimaryIP: "192.168.1.2", + PrimaryMac: "30-65-EC-6F-C4-59", + Platform: "windows", + }) + require.NoError(t, err) + + err = s.ds.UpdateHostOperatingSystem(context.Background(), host.ID, fleet.OperatingSystem{ + Name: "windows", + Version: "10.0.19042.1234", + Arch: "64bit", + Platform: "windows", + }) + require.NoError(t, err) + allos, err := s.ds.ListOperatingSystems(context.Background()) + require.NoError(t, err) + var os fleet.OperatingSystem + for _, o := range allos { + if o.ID > os.ID { + os = o + } + } + + _, err = s.ds.InsertOSVulnerability(context.Background(), fleet.OSVulnerability{ + OSID: os.ID, + CVE: "CVE-2021-1234", + }, fleet.MSRCSource) + require.NoError(t, err) + + res, err := s.ds.UpdateHostSoftware(context.Background(), host.ID, []fleet.Software{ + {Name: "foo", Version: "0.0.1", Source: "chrome_extensions"}, + }) + require.NoError(t, err) + sw := res.Inserted[0] + + _, err = s.ds.InsertSoftwareVulnerability(context.Background(), fleet.SoftwareVulnerability{ + SoftwareID: sw.ID, + CVE: "CVE-2021-1235", + }, fleet.NVDSource) + require.NoError(t, err) + + // insert CVEMeta + mockTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) + err = s.ds.InsertCVEMeta(context.Background(), []fleet.CVEMeta{ + { + CVE: "CVE-2021-1234", + CVSSScore: ptr.Float64(7.5), + EPSSProbability: ptr.Float64(0.5), + CISAKnownExploit: ptr.Bool(true), + Published: ptr.Time(mockTime), + Description: "Test CVE 2021-1234", + }, + { + CVE: "CVE-2021-1235", + CVSSScore: ptr.Float64(5.4), + EPSSProbability: ptr.Float64(0.6), + CISAKnownExploit: ptr.Bool(false), + Published: ptr.Time(mockTime), + Description: "Test CVE 2021-1235", + }, + }) + require.NoError(t, err) + + err = s.ds.UpdateVulnerabilityHostCounts(context.Background()) + require.NoError(t, err) + + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp) + require.Empty(t, resp.Err) + require.Len(s.T(), resp.Vulnerabilities, 2) + require.Equal(t, resp.Count, uint(2)) + require.False(t, resp.Meta.HasPreviousResults) + require.False(t, resp.Meta.HasNextResults) + + expected := map[string]struct { + fleet.CVEMeta + HostCount uint + DetailsLink string + Source fleet.VulnerabilitySource + }{ + "CVE-2021-1234": { + HostCount: 1, + DetailsLink: "https://msrc.microsoft.com/update-guide/en-US/vulnerability/CVE-2021-1234", + }, + "CVE-2021-1235": { + HostCount: 1, + DetailsLink: "https://nvd.nist.gov/vuln/detail/CVE-2021-1235", + }, + } + + for _, vuln := range resp.Vulnerabilities { + expectedVuln, ok := expected[vuln.CVE] + require.True(t, ok) + require.Equal(t, expectedVuln.HostCount, vuln.HostCount) + require.Equal(t, expectedVuln.DetailsLink, vuln.DetailsLink) + require.Empty(t, vuln.CVSSScore) + } + + // Test Team Filter + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp, "team_id", "1") + require.Len(s.T(), resp.Vulnerabilities, 0) + + team, err := s.ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"}) + require.NoError(t, err) + err = s.ds.AddHostsToTeam(context.Background(), &team.ID, []uint{host.ID}) + require.NoError(t, err) + + err = s.ds.UpdateVulnerabilityHostCounts(context.Background()) + require.NoError(t, err) + + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp, "team_id", fmt.Sprintf("%d", team.ID)) + require.Len(t, resp.Vulnerabilities, 2) + require.Equal(t, uint(2), resp.Count) + require.False(t, resp.Meta.HasPreviousResults) + require.False(t, resp.Meta.HasNextResults) + require.Empty(t, resp.Err) + + for _, vuln := range resp.Vulnerabilities { + expectedVuln, ok := expected[vuln.CVE] + require.True(t, ok) + require.Equal(t, expectedVuln.HostCount, vuln.HostCount) + require.Equal(t, expectedVuln.DetailsLink, vuln.DetailsLink) + require.Empty(t, vuln.CVSSScore) + } +} + func (s *integrationTestSuite) TestOSVersions() { t := s.T() diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index 888af14c12..c9b8a7608f 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -3133,6 +3133,174 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() { } } +func (s *integrationEnterpriseTestSuite) TestListVulnerabilities() { + t := s.T() + var resp listVulnerabilitiesResponse + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp) + + // Invalid Order Key + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusBadRequest, &resp, "order_key", "foo", "order_direction", "asc") + + // EE Only Order Key + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp, "order_key", "cvss_score", "order_direction", "asc") + + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp) + require.Len(s.T(), resp.Vulnerabilities, 0) + + host, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"), + OsqueryHostID: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"), + UUID: t.Name() + "2", + Hostname: t.Name() + "foo2.local", + PrimaryIP: "192.168.1.2", + PrimaryMac: "30-65-EC-6F-C4-59", + Platform: "windows", + }) + require.NoError(t, err) + + err = s.ds.UpdateHostOperatingSystem(context.Background(), host.ID, fleet.OperatingSystem{ + Name: "windows", + Version: "10.0.19042.1234", + Arch: "64bit", + Platform: "windows", + }) + require.NoError(t, err) + allos, err := s.ds.ListOperatingSystems(context.Background()) + require.NoError(t, err) + var os fleet.OperatingSystem + for _, o := range allos { + if o.ID > os.ID { + os = o + } + } + + _, err = s.ds.InsertOSVulnerability(context.Background(), fleet.OSVulnerability{ + OSID: os.ID, + CVE: "CVE-2021-1234", + }, fleet.MSRCSource) + require.NoError(t, err) + + res, err := s.ds.UpdateHostSoftware(context.Background(), host.ID, []fleet.Software{ + {Name: "foo", Version: "0.0.1", Source: "chrome_extensions"}, + }) + require.NoError(t, err) + sw := res.Inserted[0] + + _, err = s.ds.InsertSoftwareVulnerability(context.Background(), fleet.SoftwareVulnerability{ + SoftwareID: sw.ID, + CVE: "CVE-2021-1235", + }, fleet.NVDSource) + require.NoError(t, err) + + // insert CVEMeta + mockTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) + err = s.ds.InsertCVEMeta(context.Background(), []fleet.CVEMeta{ + { + CVE: "CVE-2021-1234", + CVSSScore: ptr.Float64(7.5), + EPSSProbability: ptr.Float64(0.5), + CISAKnownExploit: ptr.Bool(true), + Published: ptr.Time(mockTime), + Description: "Test CVE 2021-1234", + }, + { + CVE: "CVE-2021-1235", + CVSSScore: ptr.Float64(5.4), + EPSSProbability: ptr.Float64(0.6), + CISAKnownExploit: ptr.Bool(false), + Published: ptr.Time(mockTime), + Description: "Test CVE 2021-1235", + }, + }) + require.NoError(t, err) + + err = s.ds.UpdateVulnerabilityHostCounts(context.Background()) + require.NoError(t, err) + + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp) + require.Len(s.T(), resp.Vulnerabilities, 2) + require.Equal(t, resp.Count, uint(2)) + require.False(t, resp.Meta.HasPreviousResults) + require.False(t, resp.Meta.HasNextResults) + require.Empty(t, resp.Err) + + expected := map[string]struct { + fleet.CVEMeta + HostCount uint + DetailsLink string + Source fleet.VulnerabilitySource + }{ + "CVE-2021-1234": { + HostCount: 1, + DetailsLink: "https://msrc.microsoft.com/update-guide/en-US/vulnerability/CVE-2021-1234", + CVEMeta: fleet.CVEMeta{ + CVE: "CVE-2021-1234", + CVSSScore: ptr.Float64(7.5), + EPSSProbability: ptr.Float64(0.5), + CISAKnownExploit: ptr.Bool(true), + Published: ptr.Time(mockTime), + Description: "Test CVE 2021-1234", + }, + }, + "CVE-2021-1235": { + HostCount: 1, + DetailsLink: "https://nvd.nist.gov/vuln/detail/CVE-2021-1235", + CVEMeta: fleet.CVEMeta{ + CVE: "CVE-2021-1235", + CVSSScore: ptr.Float64(5.4), + EPSSProbability: ptr.Float64(0.6), + CISAKnownExploit: ptr.Bool(false), + Published: ptr.Time(mockTime), + Description: "Test CVE 2021-1235", + }, + }, + } + + for _, vuln := range resp.Vulnerabilities { + expectedVuln, ok := expected[vuln.CVE] + require.True(t, ok) + require.Equal(t, expectedVuln.HostCount, vuln.HostCount) + require.Equal(t, expectedVuln.DetailsLink, vuln.DetailsLink) + require.Equal(t, expectedVuln.CVEMeta, vuln.CVEMeta) + } + + // EE Exploit Filter + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp, "exploit", "true") + require.Len(t, resp.Vulnerabilities, 1) + require.Equal(t, "CVE-2021-1234", resp.Vulnerabilities[0].CVE) + + // Test Team Filter + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp, "team_id", "1") + require.Len(s.T(), resp.Vulnerabilities, 0) + + team, err := s.ds.NewTeam(context.Background(), &fleet.Team{Name: "team1"}) + require.NoError(t, err) + err = s.ds.AddHostsToTeam(context.Background(), &team.ID, []uint{host.ID}) + require.NoError(t, err) + + err = s.ds.UpdateVulnerabilityHostCounts(context.Background()) + require.NoError(t, err) + + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp, "team_id", fmt.Sprintf("%d", team.ID)) + require.Len(t, resp.Vulnerabilities, 2) + require.Equal(t, uint(2), resp.Count) + require.False(t, resp.Meta.HasPreviousResults) + require.False(t, resp.Meta.HasNextResults) + require.Empty(t, resp.Err) + + for _, vuln := range resp.Vulnerabilities { + expectedVuln, ok := expected[vuln.CVE] + require.True(t, ok) + require.Equal(t, expectedVuln.HostCount, vuln.HostCount) + require.Equal(t, expectedVuln.DetailsLink, vuln.DetailsLink) + require.Equal(t, expectedVuln.CVEMeta, vuln.CVEMeta) + } +} + func (s *integrationEnterpriseTestSuite) TestOSVersions() { t := s.T() @@ -3146,13 +3314,16 @@ func (s *integrationEnterpriseTestSuite) TestOSVersions() { // set operating system information on a host require.NoError(t, s.ds.UpdateHostOperatingSystem(context.Background(), hosts[0].ID, testOS)) - var osID uint + var osinfo struct { + ID uint `db:"id"` + OSVersionID uint `db:"os_version_id"` + } mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error { - return sqlx.GetContext(context.Background(), q, &osID, - `SELECT id FROM operating_systems WHERE name = ? AND version = ? AND arch = ? AND kernel_version = ? AND platform = ?`, + return sqlx.GetContext(context.Background(), q, &osinfo, + `SELECT id, os_version_id FROM operating_systems WHERE name = ? AND version = ? AND arch = ? AND kernel_version = ? AND platform = ?`, testOS.Name, testOS.Version, testOS.Arch, testOS.KernelVersion, testOS.Platform) }) - require.Greater(t, osID, uint(0)) + require.Greater(t, osinfo.ID, uint(0)) resp = listHostsResponse{} s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "os_name", testOS.Name, "os_version", testOS.Version) @@ -3160,7 +3331,7 @@ func (s *integrationEnterpriseTestSuite) TestOSVersions() { expected := resp.Hosts[0] resp = listHostsResponse{} - s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "os_id", fmt.Sprintf("%d", osID)) + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "os_id", fmt.Sprintf("%d", osinfo.ID)) require.Len(t, resp.Hosts, 1) require.Equal(t, expected, resp.Hosts[0]) @@ -3169,7 +3340,7 @@ func (s *integrationEnterpriseTestSuite) TestOSVersions() { // insert OS Vulns _, err := s.ds.InsertOSVulnerability(context.Background(), fleet.OSVulnerability{ - OSID: osID, + OSID: osinfo.ID, CVE: "CVE-2021-1234", }, fleet.MSRCSource) require.NoError(t, err) @@ -3205,7 +3376,7 @@ func (s *integrationEnterpriseTestSuite) TestOSVersions() { require.Equal(t, vulnMeta[0].Description, **osVersionsResp.OSVersions[0].Vulnerabilities[0].Description) var osVersionResp getOSVersionResponse - s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", 1), nil, http.StatusOK, &osVersionResp) + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/os_versions/%d", osinfo.OSVersionID), nil, http.StatusOK, &osVersionResp) require.Equal(t, &osVersionsResp.OSVersions[0], osVersionResp.OSVersion) // return empty json if UpdateOSVersions cron hasn't run yet for new team diff --git a/server/service/vulnerabilities.go b/server/service/vulnerabilities.go new file mode 100644 index 0000000000..7e620144cb --- /dev/null +++ b/server/service/vulnerabilities.go @@ -0,0 +1,92 @@ +package service + +import ( + "context" + "fmt" + + "github.com/fleetdm/fleet/v4/server/fleet" +) + +var freeValidVulnSortColumns = []string{ + "cve", + "host_count", + "host_count_updated_at", + "created_at", +} + +type listVulnerabilitiesRequest struct { + fleet.VulnListOptions +} + +type listVulnerabilitiesResponse struct { + Vulnerabilities []fleet.VulnerabilityWithMetadata `json:"vulnerabilities"` + Count uint `json:"count"` + Meta *fleet.PaginationMetadata `json:"meta,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r listVulnerabilitiesResponse) error() error { return r.Err } + +func listVulnerabilitiesEndpoint(ctx context.Context, req interface{}, svc fleet.Service) (errorer, error) { + request := req.(*listVulnerabilitiesRequest) + vulns, meta, err := svc.ListVulnerabilities(ctx, request.VulnListOptions) + if err != nil { + return listVulnerabilitiesResponse{Err: err}, nil + } + + count, err := svc.CountVulnerabilities(ctx, request.VulnListOptions) + if err != nil { + return listVulnerabilitiesResponse{Err: err}, nil + } + + return listVulnerabilitiesResponse{ + Vulnerabilities: vulns, + Meta: meta, + Count: count, + }, nil +} + +func (svc *Service) ListVulnerabilities(ctx context.Context, opt fleet.VulnListOptions) ([]fleet.VulnerabilityWithMetadata, *fleet.PaginationMetadata, error) { + if err := svc.authz.Authorize(ctx, &fleet.AuthzSoftwareInventory{ + TeamID: &opt.TeamID, + }, fleet.ActionRead); err != nil { + return nil, nil, err + } + + if len(opt.ValidSortColumns) == 0 { + opt.ValidSortColumns = freeValidVulnSortColumns + } + + if !opt.HasValidSortColumn() { + return nil, nil, badRequest("invalid order key") + } + + if opt.KnownExploit && !opt.IsEE { + return nil, nil, fleet.ErrMissingLicense + } + + vulns, meta, err := svc.ds.ListVulnerabilities(ctx, opt) + if err != nil { + return nil, nil, err + } + + for i, vuln := range vulns { + if vuln.Source == fleet.MSRCSource { + vulns[i].DetailsLink = fmt.Sprintf("https://msrc.microsoft.com/update-guide/en-US/vulnerability/%s", vuln.CVE) + } else { + vulns[i].DetailsLink = fmt.Sprintf("https://nvd.nist.gov/vuln/detail/%s", vuln.CVE) + } + } + + return vulns, meta, nil +} + +func (svc *Service) CountVulnerabilities(ctx context.Context, opts fleet.VulnListOptions) (uint, error) { + if err := svc.authz.Authorize(ctx, &fleet.AuthzSoftwareInventory{ + TeamID: &opts.TeamID, + }, fleet.ActionRead); err != nil { + return 0, err + } + + return svc.ds.CountVulnerabilities(ctx, opts) +} diff --git a/server/service/vulnerabilities_test.go b/server/service/vulnerabilities_test.go new file mode 100644 index 0000000000..95bc20bb94 --- /dev/null +++ b/server/service/vulnerabilities_test.go @@ -0,0 +1,53 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/fleetdm/fleet/v4/server/contexts/viewer" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/stretchr/testify/require" +) + +func TestListVulnerabilities(t *testing.T) { + ds := new(mock.Store) + svc, ctx := newTestService(t, ds, nil, nil) + ctx = viewer.NewContext(ctx, viewer.Viewer{User: &fleet.User{GlobalRole: ptr.String(fleet.RoleAdmin)}}) + + ds.ListVulnerabilitiesFunc = func(cxt context.Context, opt fleet.VulnListOptions) ([]fleet.VulnerabilityWithMetadata, *fleet.PaginationMetadata, error) { + return []fleet.VulnerabilityWithMetadata{ + { + CVEMeta: fleet.CVEMeta{ + CVE: "CVE-2019-1234", + Description: "A vulnerability", + }, + CreatedAt: time.Now(), + HostCount: 10, + }, + }, nil, nil + } + + t.Run("no list options", func(t *testing.T) { + _, _, err := svc.ListVulnerabilities(ctx, fleet.VulnListOptions{}) + require.NoError(t, err) + }) + + t.Run("can only sort by supported columns", func(t *testing.T) { + // invalid order key + opts := fleet.VulnListOptions{ListOptions: fleet.ListOptions{ + OrderKey: "invalid", + }, ValidSortColumns: freeValidVulnSortColumns} + + _, _, err := svc.ListVulnerabilities(ctx, opts) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid order key") + + // valid order key + opts.OrderKey = "cve" + _, _, err = svc.ListVulnerabilities(ctx, opts) + require.NoError(t, err) + }) +}