From 8cb6722df8510de8a5a8ae1461dd7ef533d555e5 Mon Sep 17 00:00:00 2001 From: Tim Lee Date: Tue, 20 Feb 2024 08:49:11 -0700 Subject: [PATCH] Return 0 count for team vulnerability (#16897) #16891 - [X] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [X] Added/updated tests - [X] Manual QA for all new/changed functionality --------- Co-authored-by: Victor Lyuboslavsky Co-authored-by: Victor Lyuboslavsky --- server/datastore/mysql/vulnerabilities.go | 89 ++++++++++++++--------- server/service/integration_core_test.go | 68 ++++++++++++++--- server/service/vulnerabilities.go | 11 +++ server/service/vulnerabilities_test.go | 4 + 4 files changed, 130 insertions(+), 42 deletions(-) diff --git a/server/datastore/mysql/vulnerabilities.go b/server/datastore/mysql/vulnerabilities.go index 2dbfa0e771..237c3f5624 100644 --- a/server/datastore/mysql/vulnerabilities.go +++ b/server/datastore/mysql/vulnerabilities.go @@ -16,44 +16,59 @@ func (ds *Datastore) Vulnerability(ctx context.Context, cve string, teamID *uint var vuln fleet.VulnerabilityWithMetadata eeSelectStmt := ` - SELECT - vhc.cve, - MIN(COALESCE(osv.created_at, sc.created_at, NOW())) AS created_at, + SELECT DISTINCT + cm.cve, + COALESCE(LEAST(osv.created_at, sc.created_at), NOW()) AS created_at, COALESCE(osv.source, sc.source, 0) AS source, cm.cvss_score, cm.epss_probability, cm.cisa_known_exploit, cm.published, - COALESCE(cm.description, '') AS description, - vhc.host_count, - vhc.updated_at as host_count_updated_at - FROM - vulnerability_host_counts vhc - LEFT JOIN cve_meta cm ON cm.cve = vhc.cve - LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = vhc.cve - LEFT JOIN software_cve sc ON sc.cve = vhc.cve - WHERE vhc.cve = ? - ` - eeGroupBy := " GROUP BY vhc.cve, source, cm.cvss_score, cm.epss_probability, cm.cisa_known_exploit, cm.published, description, vhc.host_count, host_count_updated_at" + cm.description, + COALESCE(vhc.host_count, 0) as host_count, + COALESCE(vhc.updated_at, NOW()) as host_count_updated_at + FROM cve_meta cm + JOIN ( + SELECT cve + FROM software_cve + WHERE cve = ? + + UNION + + SELECT cve + FROM operating_system_vulnerabilities + WHERE cve = ? + ) AS cve_table ON cm.cve = cve_table.cve + LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = cm.cve + LEFT JOIN software_cve sc ON sc.cve = cm.cve + LEFT JOIN vulnerability_host_counts vhc ON cm.cve = vhc.cve +` freeSelectStmt := ` - SELECT - vhc.cve, - MIN(COALESCE(osv.created_at, sc.created_at, NOW())) AS created_at, + SELECT DISTINCT + union_cve.cve, + COALESCE(LEAST(osv.created_at, sc.created_at), NOW()) AS created_at, COALESCE(osv.source, sc.source, 0) AS source, - vhc.host_count, - vhc.updated_at as host_count_updated_at - FROM - vulnerability_host_counts vhc - LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = vhc.cve - LEFT JOIN software_cve sc ON sc.cve = vhc.cve - WHERE vhc.cve = ? + COALESCE(vhc.host_count, 0) as host_count, + COALESCE(vhc.updated_at, NOW()) as host_count_updated_at + FROM ( + SELECT cve, created_at, source + FROM operating_system_vulnerabilities + WHERE cve = ? + + UNION + + SELECT cve, created_at, source + FROM software_cve + WHERE cve = ? + ) AS union_cve + LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = union_cve.cve + LEFT JOIN software_cve sc ON sc.cve = union_cve.cve + LEFT JOIN vulnerability_host_counts vhc ON vhc.cve = union_cve.cve ` - freeGroupBy := " GROUP BY vhc.cve, source, vhc.host_count, host_count_updated_at" - var args []interface{} - args = append(args, cve) + args = append(args, cve, cve) if teamID != nil { eeSelectStmt += " AND vhc.team_id = ?" @@ -64,9 +79,6 @@ func (ds *Datastore) Vulnerability(ctx context.Context, cve string, teamID *uint freeSelectStmt += " AND vhc.team_id = 0" } - eeSelectStmt += eeGroupBy - freeSelectStmt += freeGroupBy - var selectStmt string if includeCVEScores { selectStmt = eeSelectStmt @@ -81,13 +93,24 @@ func (ds *Datastore) Vulnerability(ctx context.Context, cve string, teamID *uint } return nil, ctxerr.Wrap(ctx, err, "fetching vulnerability") } + + if vuln.HostCount == 0 { + var msg string + if teamID == nil { + msg = "global" + } else { + msg = fmt.Sprintf("team %d", *teamID) + } + return nil, ctxerr.Wrap(ctx, notFound(fmt.Sprintf("Vulnerability for %s", msg)).WithName(cve)) + } + return &vuln, nil } func (ds *Datastore) OSVersionsByCVE(ctx context.Context, cve string, teamID *uint) (vos []*fleet.VulnerableOS, updatedAt time.Time, err error) { osvs, err := ds.OSVersions(ctx, teamID, nil, nil, nil) - if err != nil { - return nil, updatedAt, ctxerr.Wrap(ctx, err, "fetching OS versions by CVE") + if err != nil && !fleet.IsNotFound(err) { + return nil, updatedAt, ctxerr.Wrap(ctx, err, "fetching team OS versions") } updatedAt = osvs.CountsUpdatedAt @@ -108,7 +131,7 @@ func (ds *Datastore) OSVersionsByCVE(ctx context.Context, cve string, teamID *ui if err == sql.ErrNoRows { return nil, updatedAt, ctxerr.Wrap(ctx, notFound("Vulnerability").WithName(cve)) } - return vos, updatedAt, ctxerr.Wrap(ctx, err, "fetching OS versions by CVE") + return vos, updatedAt, ctxerr.Wrap(ctx, err, "fetching OS version and resolved version by CVE") } for _, osv := range osvs.OSVersions { diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index dd55254a7a..10df9624c8 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -7425,10 +7425,10 @@ func (s *integrationTestSuite) TestListVulnerabilities() { LabelUpdatedAt: time.Now(), PolicyUpdatedAt: time.Now(), SeenTime: time.Now(), - NodeKey: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"), - OsqueryHostID: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"), - UUID: t.Name() + "2", - Hostname: t.Name() + "foo2.local", + NodeKey: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "1"), + OsqueryHostID: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "1"), + UUID: t.Name() + "1", + Hostname: t.Name() + "foo1.local", PrimaryIP: "192.168.1.2", PrimaryMac: "30-65-EC-6F-C4-59", Platform: "windows", @@ -7474,15 +7474,43 @@ func (s *integrationTestSuite) TestListVulnerabilities() { }) require.NoError(t, err) - err = s.ds.SyncHostsSoftware(context.Background(), time.Now()) - require.NoError(t, err) - _, err = s.ds.InsertSoftwareVulnerability(context.Background(), fleet.SoftwareVulnerability{ SoftwareID: sw.ID, CVE: "CVE-2021-1235", }, fleet.NVDSource) require.NoError(t, err) + err = s.ds.SyncHostsSoftware(context.Background(), time.Now()) + require.NoError(t, err) + + host2, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"), + OsqueryHostID: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"), + UUID: t.Name() + "2", + Hostname: t.Name() + "foo2.local", + PrimaryIP: "192.168.1.2", + PrimaryMac: "30-65-EC-6F-C4-59", + Platform: "windows", + }) + require.NoError(t, err) + + res2, err := s.ds.UpdateHostSoftware(context.Background(), host2.ID, []fleet.Software{ + {Name: "Firefox", Version: "0.0.1", Source: "programs"}, + }) + require.NoError(t, err) + sw2 := res2.Inserted[0] + + // insert software vuln outside of host scope + _, err = s.ds.InsertSoftwareVulnerability(context.Background(), fleet.SoftwareVulnerability{ + SoftwareID: sw2.ID, + CVE: "CVE-2021-1236", + }, fleet.NVDSource) + require.NoError(t, err) + // insert CVEMeta mockTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC) err = s.ds.InsertCVEMeta(context.Background(), []fleet.CVEMeta{ @@ -7502,6 +7530,14 @@ func (s *integrationTestSuite) TestListVulnerabilities() { Published: ptr.Time(mockTime), Description: "Test CVE 2021-1235", }, + { + CVE: "CVE-2021-1236", + CVSSScore: ptr.Float64(5.4), + EPSSProbability: ptr.Float64(0.6), + CISAKnownExploit: ptr.Bool(false), + Published: ptr.Time(mockTime), + Description: "Test CVE 2021-1236", + }, }) require.NoError(t, err) @@ -7510,8 +7546,8 @@ func (s *integrationTestSuite) TestListVulnerabilities() { s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp) require.Empty(t, resp.Err) - require.Len(s.T(), resp.Vulnerabilities, 2) - require.Equal(t, resp.Count, uint(2)) + require.Len(s.T(), resp.Vulnerabilities, 3) + require.Equal(t, resp.Count, uint(3)) require.False(t, resp.Meta.HasPreviousResults) require.False(t, resp.Meta.HasNextResults) @@ -7529,6 +7565,10 @@ func (s *integrationTestSuite) TestListVulnerabilities() { HostCount: 1, DetailsLink: "https://nvd.nist.gov/vuln/detail/CVE-2021-1235", }, + "CVE-2021-1236": { + HostCount: 1, + DetailsLink: "https://nvd.nist.gov/vuln/detail/CVE-2021-1236", + }, } for _, vuln := range resp.Vulnerabilities { @@ -7567,6 +7607,16 @@ func (s *integrationTestSuite) TestListVulnerabilities() { } var gResp getVulnerabilityResponse + // invalid cve + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities/foobar", nil, http.StatusNotFound, &gResp) + + // Valid CVE but not in team scope + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities/CVE-2021-1236", nil, http.StatusNotFound, &gResp, "team_id", fmt.Sprintf("%d", team.ID)) + + // Invalid TeamID + s.DoJSON("GET", "/api/latest/fleet/vulnerabilities/CVE-2021-1234", nil, http.StatusForbidden, &gResp, "team_id", "100") + + // Valid Global Request s.DoJSON("GET", "/api/latest/fleet/vulnerabilities/CVE-2021-1234", nil, http.StatusOK, &gResp) require.Empty(t, gResp.Err) require.Equal(t, "CVE-2021-1234", gResp.Vulnerability.CVE) diff --git a/server/service/vulnerabilities.go b/server/service/vulnerabilities.go index 671c140256..82bf22bcd1 100644 --- a/server/service/vulnerabilities.go +++ b/server/service/vulnerabilities.go @@ -5,6 +5,8 @@ import ( "fmt" "time" + "github.com/fleetdm/fleet/v4/server/authz" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" ) @@ -146,6 +148,15 @@ func (svc *Service) Vulnerability(ctx context.Context, cve string, teamID *uint, return nil, err } + if teamID != nil { + exists, err := svc.ds.TeamExists(ctx, *teamID) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "checking if team exists") + } else if !exists { + return nil, authz.ForbiddenWithInternal("team does not exist", nil, nil, nil) + } + } + vuln, err := svc.ds.Vulnerability(ctx, cve, teamID, useCVSScores) if err != nil { return nil, err diff --git a/server/service/vulnerabilities_test.go b/server/service/vulnerabilities_test.go index 967297531a..7c80d9b77a 100644 --- a/server/service/vulnerabilities_test.go +++ b/server/service/vulnerabilities_test.go @@ -69,6 +69,10 @@ func TestVulnerabilitesAuth(t *testing.T) { return 0, nil } + ds.TeamExistsFunc = func(cxt context.Context, teamID uint) (bool, error) { + return true, nil + } + for _, tc := range []struct { name string user *fleet.User