Return 0 count for team vulnerability (#16897)

#16891 

- [X] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [X] Added/updated tests
- [X] Manual QA for all new/changed functionality

---------

Co-authored-by: Victor Lyuboslavsky <victor@fleetdm.com>
Co-authored-by: Victor Lyuboslavsky <victor.lyuboslavsky@gmail.com>
This commit is contained in:
Tim Lee 2024-02-20 08:49:11 -07:00 committed by GitHub
parent 35ca4ee32b
commit 8cb6722df8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 130 additions and 42 deletions

View file

@ -16,44 +16,59 @@ func (ds *Datastore) Vulnerability(ctx context.Context, cve string, teamID *uint
var vuln fleet.VulnerabilityWithMetadata
eeSelectStmt := `
SELECT
vhc.cve,
MIN(COALESCE(osv.created_at, sc.created_at, NOW())) AS created_at,
SELECT DISTINCT
cm.cve,
COALESCE(LEAST(osv.created_at, sc.created_at), NOW()) AS created_at,
COALESCE(osv.source, sc.source, 0) AS source,
cm.cvss_score,
cm.epss_probability,
cm.cisa_known_exploit,
cm.published,
COALESCE(cm.description, '') AS description,
vhc.host_count,
vhc.updated_at as host_count_updated_at
FROM
vulnerability_host_counts vhc
LEFT JOIN cve_meta cm ON cm.cve = vhc.cve
LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = vhc.cve
LEFT JOIN software_cve sc ON sc.cve = vhc.cve
WHERE vhc.cve = ?
`
eeGroupBy := " GROUP BY vhc.cve, source, cm.cvss_score, cm.epss_probability, cm.cisa_known_exploit, cm.published, description, vhc.host_count, host_count_updated_at"
cm.description,
COALESCE(vhc.host_count, 0) as host_count,
COALESCE(vhc.updated_at, NOW()) as host_count_updated_at
FROM cve_meta cm
JOIN (
SELECT cve
FROM software_cve
WHERE cve = ?
UNION
SELECT cve
FROM operating_system_vulnerabilities
WHERE cve = ?
) AS cve_table ON cm.cve = cve_table.cve
LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = cm.cve
LEFT JOIN software_cve sc ON sc.cve = cm.cve
LEFT JOIN vulnerability_host_counts vhc ON cm.cve = vhc.cve
`
freeSelectStmt := `
SELECT
vhc.cve,
MIN(COALESCE(osv.created_at, sc.created_at, NOW())) AS created_at,
SELECT DISTINCT
union_cve.cve,
COALESCE(LEAST(osv.created_at, sc.created_at), NOW()) AS created_at,
COALESCE(osv.source, sc.source, 0) AS source,
vhc.host_count,
vhc.updated_at as host_count_updated_at
FROM
vulnerability_host_counts vhc
LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = vhc.cve
LEFT JOIN software_cve sc ON sc.cve = vhc.cve
WHERE vhc.cve = ?
COALESCE(vhc.host_count, 0) as host_count,
COALESCE(vhc.updated_at, NOW()) as host_count_updated_at
FROM (
SELECT cve, created_at, source
FROM operating_system_vulnerabilities
WHERE cve = ?
UNION
SELECT cve, created_at, source
FROM software_cve
WHERE cve = ?
) AS union_cve
LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = union_cve.cve
LEFT JOIN software_cve sc ON sc.cve = union_cve.cve
LEFT JOIN vulnerability_host_counts vhc ON vhc.cve = union_cve.cve
`
freeGroupBy := " GROUP BY vhc.cve, source, vhc.host_count, host_count_updated_at"
var args []interface{}
args = append(args, cve)
args = append(args, cve, cve)
if teamID != nil {
eeSelectStmt += " AND vhc.team_id = ?"
@ -64,9 +79,6 @@ func (ds *Datastore) Vulnerability(ctx context.Context, cve string, teamID *uint
freeSelectStmt += " AND vhc.team_id = 0"
}
eeSelectStmt += eeGroupBy
freeSelectStmt += freeGroupBy
var selectStmt string
if includeCVEScores {
selectStmt = eeSelectStmt
@ -81,13 +93,24 @@ func (ds *Datastore) Vulnerability(ctx context.Context, cve string, teamID *uint
}
return nil, ctxerr.Wrap(ctx, err, "fetching vulnerability")
}
if vuln.HostCount == 0 {
var msg string
if teamID == nil {
msg = "global"
} else {
msg = fmt.Sprintf("team %d", *teamID)
}
return nil, ctxerr.Wrap(ctx, notFound(fmt.Sprintf("Vulnerability for %s", msg)).WithName(cve))
}
return &vuln, nil
}
func (ds *Datastore) OSVersionsByCVE(ctx context.Context, cve string, teamID *uint) (vos []*fleet.VulnerableOS, updatedAt time.Time, err error) {
osvs, err := ds.OSVersions(ctx, teamID, nil, nil, nil)
if err != nil {
return nil, updatedAt, ctxerr.Wrap(ctx, err, "fetching OS versions by CVE")
if err != nil && !fleet.IsNotFound(err) {
return nil, updatedAt, ctxerr.Wrap(ctx, err, "fetching team OS versions")
}
updatedAt = osvs.CountsUpdatedAt
@ -108,7 +131,7 @@ func (ds *Datastore) OSVersionsByCVE(ctx context.Context, cve string, teamID *ui
if err == sql.ErrNoRows {
return nil, updatedAt, ctxerr.Wrap(ctx, notFound("Vulnerability").WithName(cve))
}
return vos, updatedAt, ctxerr.Wrap(ctx, err, "fetching OS versions by CVE")
return vos, updatedAt, ctxerr.Wrap(ctx, err, "fetching OS version and resolved version by CVE")
}
for _, osv := range osvs.OSVersions {

View file

@ -7425,10 +7425,10 @@ func (s *integrationTestSuite) TestListVulnerabilities() {
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"),
OsqueryHostID: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"),
UUID: t.Name() + "2",
Hostname: t.Name() + "foo2.local",
NodeKey: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "1"),
OsqueryHostID: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "1"),
UUID: t.Name() + "1",
Hostname: t.Name() + "foo1.local",
PrimaryIP: "192.168.1.2",
PrimaryMac: "30-65-EC-6F-C4-59",
Platform: "windows",
@ -7474,15 +7474,43 @@ func (s *integrationTestSuite) TestListVulnerabilities() {
})
require.NoError(t, err)
err = s.ds.SyncHostsSoftware(context.Background(), time.Now())
require.NoError(t, err)
_, err = s.ds.InsertSoftwareVulnerability(context.Background(), fleet.SoftwareVulnerability{
SoftwareID: sw.ID,
CVE: "CVE-2021-1235",
}, fleet.NVDSource)
require.NoError(t, err)
err = s.ds.SyncHostsSoftware(context.Background(), time.Now())
require.NoError(t, err)
host2, err := s.ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"),
OsqueryHostID: ptr.String(strings.ReplaceAll(t.Name(), "/", "_") + "2"),
UUID: t.Name() + "2",
Hostname: t.Name() + "foo2.local",
PrimaryIP: "192.168.1.2",
PrimaryMac: "30-65-EC-6F-C4-59",
Platform: "windows",
})
require.NoError(t, err)
res2, err := s.ds.UpdateHostSoftware(context.Background(), host2.ID, []fleet.Software{
{Name: "Firefox", Version: "0.0.1", Source: "programs"},
})
require.NoError(t, err)
sw2 := res2.Inserted[0]
// insert software vuln outside of host scope
_, err = s.ds.InsertSoftwareVulnerability(context.Background(), fleet.SoftwareVulnerability{
SoftwareID: sw2.ID,
CVE: "CVE-2021-1236",
}, fleet.NVDSource)
require.NoError(t, err)
// insert CVEMeta
mockTime := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
err = s.ds.InsertCVEMeta(context.Background(), []fleet.CVEMeta{
@ -7502,6 +7530,14 @@ func (s *integrationTestSuite) TestListVulnerabilities() {
Published: ptr.Time(mockTime),
Description: "Test CVE 2021-1235",
},
{
CVE: "CVE-2021-1236",
CVSSScore: ptr.Float64(5.4),
EPSSProbability: ptr.Float64(0.6),
CISAKnownExploit: ptr.Bool(false),
Published: ptr.Time(mockTime),
Description: "Test CVE 2021-1236",
},
})
require.NoError(t, err)
@ -7510,8 +7546,8 @@ func (s *integrationTestSuite) TestListVulnerabilities() {
s.DoJSON("GET", "/api/latest/fleet/vulnerabilities", nil, http.StatusOK, &resp)
require.Empty(t, resp.Err)
require.Len(s.T(), resp.Vulnerabilities, 2)
require.Equal(t, resp.Count, uint(2))
require.Len(s.T(), resp.Vulnerabilities, 3)
require.Equal(t, resp.Count, uint(3))
require.False(t, resp.Meta.HasPreviousResults)
require.False(t, resp.Meta.HasNextResults)
@ -7529,6 +7565,10 @@ func (s *integrationTestSuite) TestListVulnerabilities() {
HostCount: 1,
DetailsLink: "https://nvd.nist.gov/vuln/detail/CVE-2021-1235",
},
"CVE-2021-1236": {
HostCount: 1,
DetailsLink: "https://nvd.nist.gov/vuln/detail/CVE-2021-1236",
},
}
for _, vuln := range resp.Vulnerabilities {
@ -7567,6 +7607,16 @@ func (s *integrationTestSuite) TestListVulnerabilities() {
}
var gResp getVulnerabilityResponse
// invalid cve
s.DoJSON("GET", "/api/latest/fleet/vulnerabilities/foobar", nil, http.StatusNotFound, &gResp)
// Valid CVE but not in team scope
s.DoJSON("GET", "/api/latest/fleet/vulnerabilities/CVE-2021-1236", nil, http.StatusNotFound, &gResp, "team_id", fmt.Sprintf("%d", team.ID))
// Invalid TeamID
s.DoJSON("GET", "/api/latest/fleet/vulnerabilities/CVE-2021-1234", nil, http.StatusForbidden, &gResp, "team_id", "100")
// Valid Global Request
s.DoJSON("GET", "/api/latest/fleet/vulnerabilities/CVE-2021-1234", nil, http.StatusOK, &gResp)
require.Empty(t, gResp.Err)
require.Equal(t, "CVE-2021-1234", gResp.Vulnerability.CVE)

View file

@ -5,6 +5,8 @@ import (
"fmt"
"time"
"github.com/fleetdm/fleet/v4/server/authz"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
)
@ -146,6 +148,15 @@ func (svc *Service) Vulnerability(ctx context.Context, cve string, teamID *uint,
return nil, err
}
if teamID != nil {
exists, err := svc.ds.TeamExists(ctx, *teamID)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "checking if team exists")
} else if !exists {
return nil, authz.ForbiddenWithInternal("team does not exist", nil, nil, nil)
}
}
vuln, err := svc.ds.Vulnerability(ctx, cve, teamID, useCVSScores)
if err != nil {
return nil, err

View file

@ -69,6 +69,10 @@ func TestVulnerabilitesAuth(t *testing.T) {
return 0, nil
}
ds.TeamExistsFunc = func(cxt context.Context, teamID uint) (bool, error) {
return true, nil
}
for _, tc := range []struct {
name string
user *fleet.User