diff --git a/server/datastore/mysql/vulnerabilities.go b/server/datastore/mysql/vulnerabilities.go index 0247ef47b1..97542dd7ed 100644 --- a/server/datastore/mysql/vulnerabilities.go +++ b/server/datastore/mysql/vulnerabilities.go @@ -190,9 +190,9 @@ func (ds *Datastore) ListVulnerabilities(ctx context.Context, opt fleet.VulnList // Define base select statements for EE and Free versions eeSelectStmt := ` SELECT - vhc.cve, - MIN(COALESCE(osv.created_at, sc.created_at, NOW())) AS created_at, - COALESCE(osv.source, sc.source, 0) AS source, + combined.cve as cve, + MIN(combined.created_at) as created_at, + MIN(combined.source) as source, cm.cvss_score, cm.epss_probability, cm.cisa_known_exploit, @@ -200,24 +200,28 @@ func (ds *Datastore) ListVulnerabilities(ctx context.Context, opt fleet.VulnList cm.description, vhc.host_count as hosts_count, vhc.updated_at as hosts_count_updated_at - FROM - vulnerability_host_counts vhc - LEFT JOIN cve_meta cm ON cm.cve = vhc.cve - LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = vhc.cve - LEFT JOIN software_cve sc ON sc.cve = vhc.cve + FROM ( + SELECT cve, created_at, source FROM software_cve + UNION + SELECT cve, created_at, source FROM operating_system_vulnerabilities + ) AS combined + INNER JOIN vulnerability_host_counts vhc ON vhc.cve = combined.cve + LEFT JOIN cve_meta cm ON cm.cve = combined.cve WHERE vhc.host_count > 0 ` freeSelectStmt := ` SELECT - vhc.cve, - MIN(COALESCE(osv.created_at, sc.created_at, NOW())) AS created_at, - COALESCE(osv.source, sc.source, 0) AS source, + combined.cve as cve, + MIN(combined.created_at) as created_at, + MIN(combined.source) as source, vhc.host_count as hosts_count, vhc.updated_at as hosts_count_updated_at - FROM - vulnerability_host_counts vhc - LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = vhc.cve - LEFT JOIN software_cve sc ON sc.cve = vhc.cve + FROM ( + SELECT cve, created_at, source FROM software_cve + UNION + SELECT cve, created_at, source FROM operating_system_vulnerabilities + ) AS combined + INNER JOIN vulnerability_host_counts vhc ON vhc.cve = combined.cve WHERE vhc.host_count > 0 ` @@ -229,28 +233,6 @@ func (ds *Datastore) ListVulnerabilities(ctx context.Context, opt fleet.VulnList selectStmt = freeSelectStmt } - // Define group by statements for EE and Free - eeGroupBy := ` GROUP BY - vhc.cve, - source, - cm.cvss_score, - cm.epss_probability, - cm.cisa_known_exploit, - cve_published, - description, - hosts_count, - hosts_count_updated_at - ` - freeGroupBy := " GROUP BY vhc.cve, source, hosts_count, hosts_count_updated_at" - - // Choose the appropriate group by statement based on EE or Free - var groupBy string - if opt.IsEE { - groupBy = eeGroupBy - } else { - groupBy = freeGroupBy - } - // Prepare arguments for the query var args []interface{} if opt.TeamID == 0 { @@ -269,7 +251,7 @@ func (ds *Datastore) ListVulnerabilities(ctx context.Context, opt fleet.VulnList } // Append group by statement - selectStmt += groupBy + selectStmt += " GROUP BY cve, host_count, updated_at" opt.ListOptions.IncludeMetadata = !(opt.ListOptions.UsesCursorPagination()) selectStmt, args = appendListOptionsWithCursorToSQL(selectStmt, args, &opt.ListOptions) @@ -295,13 +277,17 @@ func (ds *Datastore) ListVulnerabilities(ctx context.Context, opt fleet.VulnList func (ds *Datastore) CountVulnerabilities(ctx context.Context, opt fleet.VulnListOptions) (uint, error) { selectStmt := ` - SELECT COUNT(*) - FROM vulnerability_host_counts vhc - LEFT JOIN cve_meta cm ON cm.cve = vhc.cve - LEFT JOIN operating_system_vulnerabilities osv ON osv.cve = vhc.cve - LEFT JOIN software_cve sc ON sc.cve = vhc.cve + SELECT + COUNT(*) + FROM ( + SELECT cve, created_at, source FROM software_cve + UNION + SELECT cve, created_at, source FROM operating_system_vulnerabilities + ) AS combined + INNER JOIN vulnerability_host_counts vhc ON vhc.cve = combined.cve + LEFT JOIN cve_meta cm ON cm.cve = combined.cve WHERE vhc.host_count > 0 - ` + ` var args []interface{} if opt.TeamID == 0 { selectStmt = selectStmt + " AND vhc.team_id = 0" diff --git a/server/datastore/mysql/vulnerabilities_test.go b/server/datastore/mysql/vulnerabilities_test.go index c633a69b50..29236f4715 100644 --- a/server/datastore/mysql/vulnerabilities_test.go +++ b/server/datastore/mysql/vulnerabilities_test.go @@ -60,9 +60,10 @@ func testListVulnerabilities(t *testing.T, ds *Datastore) { _, err = ds.writer(context.Background()).Exec(insertStmt, "CVE-2020-1236", 0, 20) require.NoError(t, err) + // No Vulns unless OS or Software Vulns are inserted list, _, err = ds.ListVulnerabilities(context.Background(), opts) require.NoError(t, err) - require.Len(t, list, 3) + require.Len(t, list, 0) // insert OS Vuln _, err = ds.InsertOSVulnerabilities(context.Background(), []fleet.OSVulnerability{ @@ -873,9 +874,11 @@ func testSoftwareByCVE(t *testing.T, ds *Datastore) { func assertHostCounts(t *testing.T, expected []hostCount, actual []fleet.VulnerabilityWithMetadata) { t.Helper() require.Len(t, actual, len(expected)) - for i, vuln := range actual { - require.Equal(t, expected[i].CVE, vuln.CVE.CVE) - require.Equal(t, expected[i].HostCount, vuln.HostsCount) + for _, vuln := range actual { + require.Contains(t, expected, hostCount{ + CVE: vuln.CVE.CVE, + HostCount: vuln.HostsCount, + }) } }