diff --git a/changes/18115-host_issues b/changes/18115-host_issues index b8abd9bd41..e2f6de702c 100644 --- a/changes/18115-host_issues +++ b/changes/18115-host_issues @@ -1,2 +1,4 @@ -* /api/latest/fleet/hosts now returns `critical_vulnerabilities_count` for premium users. This data is held in the new `host_issues` table. The failing policies total is updated in real-time, while the critical vulnerabilities total is updated every hour after vulnerabilities job. -* /api/latest/fleet/hosts can be sorted by total_issues_count by specifying `order_key=issues` query parameter. +* /api/latest/fleet/hosts and /api/latest/fleet/labels/:id/hosts now return `critical_vulnerabilities_count` for premium users. This data is held in the new `host_issues` table. The failing policies total is updated in real-time, while the critical vulnerabilities total is updated every hour after vulnerabilities job. +* /api/latest/fleet/hosts and /api/latest/fleet/labels/:id/hosts can be sorted by total_issues_count by specifying `order_key=issues` query parameter. +* /api/latest/hosts/:id and /api/latest/hosts/identifier/:identifier now return `critical_vulnerabilities_count` for premium users. +* For /api/latest/fleet/hosts, /api/latest/fleet/hosts/report, and /api/latest/fleet/labels/:id/hosts endpoints, the `disable_failing_policies` query parameter has been deprecated. Instead, use `disable_issues` to disable the failing policies and critical vulnerabilities counts. diff --git a/ee/server/service/hosts.go b/ee/server/service/hosts.go index 1e5fb95603..33b7959285 100644 --- a/ee/server/service/hosts.go +++ b/ee/server/service/hosts.go @@ -18,6 +18,7 @@ func (svc *Service) GetHost(ctx context.Context, id uint, opts fleet.HostDetailO // reuse GetHost, but include premium details opts.IncludeCVEScores = true opts.IncludePolicies = true + opts.IncludeCriticalVulnerabilitiesCount = true return svc.Service.GetHost(ctx, id, opts) } diff --git a/ee/server/service/teams.go b/ee/server/service/teams.go index bc05d8960a..a4a6b2993e 100644 --- a/ee/server/service/teams.go +++ b/ee/server/service/teams.go @@ -525,8 +525,8 @@ func (svc *Service) DeleteTeam(ctx context.Context, teamID uint) error { filter := fleet.TeamFilter{User: vc.User, IncludeObserver: true} opts := fleet.HostListOptions{ - TeamFilter: &teamID, - DisableFailingPolicies: true, // don't need to check policies for hosts that are being deleted + TeamFilter: &teamID, + DisableIssues: true, // don't need to check policies for hosts that are being deleted } hosts, err := svc.ds.ListHosts(ctx, filter, opts) diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 3af708b8fe..849861b932 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -653,8 +653,9 @@ SELECT WHERE host_id = h.id ) AS additional, - COALESCE(failing_policies.count, 0) AS failing_policies_count, - COALESCE(failing_policies.count, 0) AS total_issues_count, + COALESCE(host_issues.failing_policies_count, 0) AS failing_policies_count, + COALESCE(host_issues.critical_vulnerabilities_count, 0) AS critical_vulnerabilities_count, + COALESCE(host_issues.total_issues_count, 0) AS total_issues_count, hoi.version AS orbit_version, hoi.desktop_version AS fleet_desktop_version, hoi.scripts_enabled AS scripts_enabled @@ -666,22 +667,14 @@ FROM LEFT JOIN host_updates hu ON (h.id = hu.host_id) LEFT JOIN host_disks hd ON hd.host_id = h.id LEFT JOIN host_orbit_info hoi ON hoi.host_id = h.id + LEFT JOIN host_issues ON h.id = host_issues.host_id ` + hostMDMJoin + ` - JOIN ( - SELECT - count(*) as count - FROM - policy_membership - WHERE - passes = 0 - AND host_id = ? - ) failing_policies WHERE h.id = ? LIMIT 1 ` - args := []interface{}{id, id} + args := []interface{}{id} var host fleet.Host err := sqlx.GetContext(ctx, ds.reader(ctx), &host, sqlStatement, args...) @@ -969,7 +962,7 @@ func (ds *Datastore) ListHosts(ctx context.Context, filter fleet.TeamFilter, opt ` } - if !opt.DisableFailingPolicies { + if !opt.DisableIssues { sql += `, COALESCE(host_issues.failing_policies_count, 0) AS failing_policies_count, COALESCE(host_issues.critical_vulnerabilities_count, 0) AS critical_vulnerabilities_count, @@ -1073,7 +1066,7 @@ func (ds *Datastore) applyHostFilters( } failingPoliciesJoin := "" - if !opt.DisableFailingPolicies { + if !opt.DisableIssues { failingPoliciesJoin = `LEFT JOIN host_issues ON h.id = host_issues.host_id` } @@ -1610,7 +1603,7 @@ func (ds *Datastore) CountHosts(ctx context.Context, filter fleet.TeamFilter, op opt.Page = 0 opt.PerPage = 0 // We don't need the issue counts of each host for counting hosts. - opt.DisableFailingPolicies = true + opt.DisableIssues = true var params []interface{} diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index d0062e00bf..d9881e17f1 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -3451,8 +3451,10 @@ func testHostsListFailingPolicies(t *testing.T, ds *Datastore) { h2 := hosts[1] assert.Zero(t, h1.HostIssues.FailingPoliciesCount) + assert.Zero(t, *h1.HostIssues.CriticalVulnerabilitiesCount) assert.Zero(t, h1.HostIssues.TotalIssuesCount) assert.Zero(t, h2.HostIssues.FailingPoliciesCount) + assert.Zero(t, *h2.HostIssues.CriticalVulnerabilitiesCount) assert.Zero(t, h2.HostIssues.TotalIssuesCount) require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false)) @@ -3469,7 +3471,7 @@ func testHostsListFailingPolicies(t *testing.T, ds *Datastore) { require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false)) checkHostIssues(t, ds, hosts, filter, h1.ID, 1) - checkHostIssuesWithOpts(t, ds, hosts, filter, h1.ID, fleet.HostListOptions{DisableFailingPolicies: true}, 0) + checkHostIssuesWithOpts(t, ds, hosts, filter, h1.ID, fleet.HostListOptions{DisableIssues: true}, 0) } // This doesn't work when running the whole test suite, but helps inspect individual tests @@ -3580,7 +3582,7 @@ func checkHostIssuesWithOpts( assert.Equal(t, expected, foundHost.HostIssues.FailingPoliciesCount) assert.Equal(t, expected, foundHost.HostIssues.TotalIssuesCount) - if opts.DisableFailingPolicies { + if opts.DisableIssues { return } diff --git a/server/datastore/mysql/labels.go b/server/datastore/mysql/labels.go index cf0904fc48..2f654ebe33 100644 --- a/server/datastore/mysql/labels.go +++ b/server/datastore/mysql/labels.go @@ -541,18 +541,16 @@ func (ds *Datastore) ListHostsInLabel(ctx context.Context, filter fleet.TeamFilt %s %s ` - failingPoliciesSelect := `, - COALESCE(failing_policies.count, 0) AS failing_policies_count, - COALESCE(failing_policies.count, 0) AS total_issues_count + failingIssuesSelect := `, + COALESCE(host_issues.failing_policies_count, 0) AS failing_policies_count, + COALESCE(host_issues.critical_vulnerabilities_count, 0) AS critical_vulnerabilities_count, + COALESCE(host_issues.total_issues_count, 0) AS total_issues_count ` - failingPoliciesJoin := `LEFT JOIN ( - SELECT host_id, count(*) as count FROM policy_membership WHERE passes = 0 - GROUP BY host_id - ) as failing_policies ON (h.id=failing_policies.host_id)` + failingIssuesJoin := `LEFT JOIN host_issues ON h.id = host_issues.host_id` - if opt.DisableFailingPolicies { - failingPoliciesSelect = "" - failingPoliciesJoin = "" + if opt.DisableIssues { + failingIssuesSelect = "" + failingIssuesJoin = "" } deviceMappingJoin := fmt.Sprintf(`LEFT JOIN ( @@ -573,7 +571,9 @@ func (ds *Datastore) ListHostsInLabel(ctx context.Context, filter fleet.TeamFilt COALESCE(dm.device_mapping, 'null') as device_mapping` } - query := fmt.Sprintf(queryFmt, hostMDMSelect, failingPoliciesSelect, deviceMappingSelect, hostMDMJoin, failingPoliciesJoin, deviceMappingJoin) + query := fmt.Sprintf( + queryFmt, hostMDMSelect, failingIssuesSelect, deviceMappingSelect, hostMDMJoin, failingIssuesJoin, deviceMappingJoin, + ) query, params, err := ds.applyHostLabelFilters(ctx, filter, lid, query, opt) if err != nil { @@ -661,6 +661,9 @@ func (ds *Datastore) applyHostLabelFilters(ctx context.Context, filter fleet.Tea // TODO: should search columns include display_name (requires join to host_display_names)? query, whereParams, _ = hostSearchLike(query, whereParams, opt.MatchQuery, hostSearchColumns...) + if opt.ListOptions.OrderKey == "issues" { + opt.ListOptions.OrderKey = "host_issues.total_issues_count" + } query, whereParams = appendListOptionsWithCursorToSQL(query, whereParams, &opt.ListOptions) return query, append(joinParams, whereParams...), nil } diff --git a/server/datastore/mysql/labels_test.go b/server/datastore/mysql/labels_test.go index eaa601f9ef..a5605a3eb1 100644 --- a/server/datastore/mysql/labels_test.go +++ b/server/datastore/mysql/labels_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/fleetdm/fleet/v4/pkg/optjson" + "github.com/fleetdm/fleet/v4/server/contexts/license" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/test" @@ -65,7 +66,7 @@ func TestLabels(t *testing.T) { {"RecordNonExistentQueryLabelExecution", testLabelsRecordNonexistentQueryLabelExecution}, {"DeleteLabel", testDeleteLabel}, {"LabelsSummary", testLabelsSummary}, - {"ListHostsInLabelFailingPolicies", testListHostsInLabelFailingPolicies}, + {"ListHostsInLabelIssues", testListHostsInLabelIssues}, {"ListHostsInLabelDiskEncryptionStatus", testListHostsInLabelDiskEncryptionStatus}, {"HostMemberOfAllLabels", testHostMemberOfAllLabels}, {"ListHostsInLabelOSSettings", testLabelsListHostsInLabelOSSettings}, @@ -952,7 +953,7 @@ func testLabelsSummary(t *testing.T, db *Datastore) { require.Len(t, ls, 5) } -func testListHostsInLabelFailingPolicies(t *testing.T, ds *Datastore) { +func testListHostsInLabelIssues(t *testing.T, ds *Datastore) { user1 := test.NewUser(t, ds, "Alice", "alice@example.com", true) for i := 0; i < 10; i++ { _, err := ds.NewHost(context.Background(), &fleet.Host{ @@ -1003,30 +1004,100 @@ func testListHostsInLabelFailingPolicies(t *testing.T, ds *Datastore) { h2 := hosts[1] assert.Zero(t, h1.HostIssues.FailingPoliciesCount) + assert.Zero(t, *h1.HostIssues.CriticalVulnerabilitiesCount) assert.Zero(t, h1.HostIssues.TotalIssuesCount) assert.Zero(t, h2.HostIssues.FailingPoliciesCount) + assert.Zero(t, *h2.HostIssues.CriticalVulnerabilitiesCount) assert.Zero(t, h2.HostIssues.TotalIssuesCount) require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(true)}, time.Now(), false)) require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(false), p2.ID: ptr.Bool(false)}, time.Now(), false)) - checkLabelHostIssues(t, ds, hosts, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 2) + checkLabelHostIssues(t, ds, hosts, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 2, 0) + + // Add a critical vulnerability + // seed software + software := []fleet.Software{ + {Name: "foo0", Version: "0", Source: "chrome_extensions"}, // vulnerable + {Name: "foo1", Version: "1", Source: "chrome_extensions"}, + {Name: "foo2", Version: "2", Source: "chrome_extensions"}, + {Name: "foo3", Version: "3", Source: "chrome_extensions"}, + {Name: "foo4", Version: "4", Source: "chrome_extensions"}, // vulnerable + {Name: "foo5", Version: "5", Source: "chrome_extensions"}, // vulnerable + {Name: "foo6", Version: "6", Source: "chrome_extensions"}, // vulnerable + {Name: "foo7", Version: "7", Source: "chrome_extensions"}, // vulnerable + } + + for i := 0; i < len(software); i++ { + _, err := ds.UpdateHostSoftware(context.Background(), hosts[i].ID, software[:i+1]) + require.NoError(t, err) + } + + softwareItems := make([]fleet.Software, 0, len(software)) + ctx := context.Background() + require.NoError(t, sqlx.SelectContext(ctx, ds.reader(ctx), &softwareItems, "SELECT id, version FROM software")) + require.Len(t, softwareItems, len(software)) + + for _, sw := range softwareItems { + _, err := ds.InsertSoftwareVulnerability( + context.Background(), fleet.SoftwareVulnerability{ + CVE: fmt.Sprintf("CVE-%s", sw.Version), + SoftwareID: sw.ID, + }, fleet.NVDSource, + ) + require.NoError(t, err) + } + require.NoError( + t, ds.InsertCVEMeta( + ctx, []fleet.CVEMeta{ + { + CVE: "CVE-0", + CVSSScore: ptr.Float64(2 * criticalCVSSScoreCutoff), + }, + { + CVE: "CVE-3", + CVSSScore: ptr.Float64(criticalCVSSScoreCutoff), // not critical + }, + { + CVE: "CVE-4", + CVSSScore: ptr.Float64(criticalCVSSScoreCutoff + 0.001), + }, + { + CVE: "CVE-5", + CVSSScore: ptr.Float64(criticalCVSSScoreCutoff + 0.01), + }, + { + CVE: "CVE-6", + CVSSScore: ptr.Float64(criticalCVSSScoreCutoff + 0.1), + }, + { + CVE: "CVE-7", + CVSSScore: ptr.Float64(criticalCVSSScoreCutoff + 1), + }, + }, + ), + ) + // Populate critical vulnerabilities, which can be done with premium license. + ctx = license.NewContext(ctx, &fleet.LicenseInfo{Tier: fleet.TierPremium}) + assert.NoError(t, ds.UpdateHostIssuesVulnerabilities(ctx)) + checkLabelHostIssues(t, ds, hosts, l1.ID, filter, hosts[6].ID, fleet.HostListOptions{}, 0, 4) require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(false)}, time.Now(), false)) - checkLabelHostIssues(t, ds, hosts, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 1) + checkLabelHostIssues(t, ds, hosts, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 1, 1) require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2, map[uint]*bool{p.ID: ptr.Bool(true), p2.ID: ptr.Bool(true)}, time.Now(), false)) - checkLabelHostIssues(t, ds, hosts, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 0) + checkLabelHostIssues(t, ds, hosts, l1.ID, filter, h2.ID, fleet.HostListOptions{}, 0, 1) require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false)) - checkLabelHostIssues(t, ds, hosts, l1.ID, filter, h1.ID, fleet.HostListOptions{}, 1) + checkLabelHostIssues(t, ds, hosts, l1.ID, filter, h1.ID, fleet.HostListOptions{}, 1, 1) - checkLabelHostIssues(t, ds, hosts, l1.ID, filter, h1.ID, fleet.HostListOptions{DisableFailingPolicies: true}, 0) + checkLabelHostIssues(t, ds, hosts, l1.ID, filter, h1.ID, fleet.HostListOptions{DisableIssues: true}, 0, 0) + checkLabelHostIssues(t, ds, hosts, l1.ID, filter, hosts[6].ID, fleet.HostListOptions{DisableIssues: true}, 0, 0) } func checkLabelHostIssues( t *testing.T, ds *Datastore, hosts []*fleet.Host, lid uint, filter fleet.TeamFilter, hid uint, opts fleet.HostListOptions, - expected uint64, + failingPoliciesExpected uint64, criticalVulnerabilitiesExpected uint64, ) { hosts = listHostsInLabelCheckCount(t, ds, filter, lid, opts, 10) foundH2 := false @@ -1039,17 +1110,21 @@ func checkLabelHostIssues( } } require.True(t, foundH2) - assert.Equal(t, expected, foundHost.HostIssues.FailingPoliciesCount) - assert.Equal(t, expected, foundHost.HostIssues.TotalIssuesCount) + assert.Equal(t, failingPoliciesExpected, foundHost.HostIssues.FailingPoliciesCount) - if opts.DisableFailingPolicies { + if opts.DisableIssues { + assert.Nil(t, foundHost.HostIssues.CriticalVulnerabilitiesCount) + assert.Zero(t, foundHost.HostIssues.TotalIssuesCount) return } + assert.Equal(t, criticalVulnerabilitiesExpected, *foundHost.HostIssues.CriticalVulnerabilitiesCount) + assert.Equal(t, failingPoliciesExpected+criticalVulnerabilitiesExpected, foundHost.HostIssues.TotalIssuesCount) hostById, err := ds.Host(context.Background(), hid) require.NoError(t, err) - assert.Equal(t, expected, hostById.HostIssues.FailingPoliciesCount) - assert.Equal(t, expected, hostById.HostIssues.TotalIssuesCount) + assert.Equal(t, failingPoliciesExpected, hostById.HostIssues.FailingPoliciesCount) + assert.Equal(t, failingPoliciesExpected+criticalVulnerabilitiesExpected, hostById.HostIssues.TotalIssuesCount) + assert.Equal(t, foundHost.HostIssues.CriticalVulnerabilitiesCount, hostById.HostIssues.CriticalVulnerabilitiesCount) } func testListHostsInLabelDiskEncryptionStatus(t *testing.T, ds *Datastore) { diff --git a/server/datastore/mysql/migrations/tables/20240613172616_HostIssuesTable.go b/server/datastore/mysql/migrations/tables/20240613172616_HostIssuesTable.go index f831d75abe..3ab786c408 100644 --- a/server/datastore/mysql/migrations/tables/20240613172616_HostIssuesTable.go +++ b/server/datastore/mysql/migrations/tables/20240613172616_HostIssuesTable.go @@ -25,6 +25,19 @@ func Up_20240613172616(tx *sql.Tx) error { if err != nil { return fmt.Errorf("failed to create host_issues table: %w", err) } + + // Now, populate the table with failing_policies_counts + _, err = tx.Exec( + `INSERT INTO host_issues (host_id, failing_policies_count, total_issues_count) + SELECT pm.host_id, COALESCE(SUM(!pm.passes), 0), COALESCE(SUM(!pm.passes), 0) + FROM policy_membership pm + WHERE pm.passes = 0 + GROUP BY pm.host_id`, + ) + if err != nil { + return fmt.Errorf("failed to populate host_issues table: %w", err) + } + return nil } diff --git a/server/datastore/mysql/migrations/tables/20240613172616_HostIssuesTable_test.go b/server/datastore/mysql/migrations/tables/20240613172616_HostIssuesTable_test.go index 4f3cd7cc9f..688084aa91 100644 --- a/server/datastore/mysql/migrations/tables/20240613172616_HostIssuesTable_test.go +++ b/server/datastore/mysql/migrations/tables/20240613172616_HostIssuesTable_test.go @@ -10,15 +10,23 @@ import ( func TestUp_20240613172616(t *testing.T) { db := applyUpToPrev(t) - applyNext(t, db) - hostID := uint(12) - - insertStmt := `INSERT INTO host_issues (host_id, failing_policies_count, critical_vulnerabilities_count, total_issues_count) VALUES (?, ?, ?, ?)` - _, err := db.Exec(insertStmt, hostID, 1, 2, 3) + res, err := db.Exec( + ` + INSERT INTO policies (name, query, description, checksum) + VALUES ('test_policy', "", "", "abc")`, + ) require.NoError(t, err) - _, err = db.Exec(insertStmt, hostID, 4, 5, 6) - require.ErrorContains(t, err, "Error 1062") + policyID, err := res.LastInsertId() + require.NoError(t, err) + + _, err = db.Exec( + `INSERT INTO policy_membership (policy_id, host_id, passes) VALUES (?, ?, ?)`, + policyID, 1, 0, + ) + require.NoError(t, err) + + applyNext(t, db) type issues struct { HostID uint `db:"host_id"` @@ -31,6 +39,23 @@ func TestUp_20240613172616(t *testing.T) { var result issues selectStmt := `SELECT * from host_issues WHERE host_id = ?` + err = db.Get(&result, selectStmt, 1) + require.NoError(t, err) + assert.Equal(t, uint(1), result.HostID) + assert.Equal(t, uint(1), result.FailingPoliciesCount) + assert.Equal(t, uint(0), result.CriticalVulnerabilitiesCount) + assert.Equal(t, uint(1), result.TotalIssuesCount) + assert.NotZero(t, result.CreatedAt) + assert.Equal(t, result.CreatedAt, result.UpdatedAt) + + hostID := uint(12) + + insertStmt := `INSERT INTO host_issues (host_id, failing_policies_count, critical_vulnerabilities_count, total_issues_count) VALUES (?, ?, ?, ?)` + _, err = db.Exec(insertStmt, hostID, 1, 2, 3) + require.NoError(t, err) + _, err = db.Exec(insertStmt, hostID, 4, 5, 6) + require.ErrorContains(t, err, "Error 1062") + err = db.Get(&result, selectStmt, hostID) require.NoError(t, err) assert.Equal(t, hostID, result.HostID) diff --git a/server/fleet/hosts.go b/server/fleet/hosts.go index 58119cb43b..2e8c385d3d 100644 --- a/server/fleet/hosts.go +++ b/server/fleet/hosts.go @@ -155,7 +155,7 @@ type HostListOptions struct { OSVersionFilter *string OSVersionIDFilter *uint - DisableFailingPolicies bool + DisableIssues bool // MacOSSettingsFilter filters the hosts by the status of MDM configuration profiles // applied to the hosts. @@ -221,7 +221,7 @@ func (h HostListOptions) Empty() bool { h.OSIDFilter == nil && h.OSNameFilter == nil && h.OSVersionFilter == nil && - h.DisableFailingPolicies == false && + h.DisableIssues == false && h.MacOSSettingsFilter == "" && h.MacOSSettingsDiskEncryptionFilter == "" && h.MDMBootstrapPackageFilter == nil && @@ -1191,9 +1191,10 @@ type OSVersion struct { } type HostDetailOptions struct { - IncludeCVEScores bool - IncludePolicies bool - ExcludeSoftware bool + IncludeCVEScores bool + IncludeCriticalVulnerabilitiesCount bool + IncludePolicies bool + ExcludeSoftware bool } // EnrollHostLimiter defines the methods to support enforcement of enrolled diff --git a/server/service/hosts.go b/server/service/hosts.go index b884cd9092..7340c63244 100644 --- a/server/service/hosts.go +++ b/server/service/hosts.go @@ -192,11 +192,18 @@ func (svc *Service) ListHosts(ctx context.Context, opt fleet.HostListOptions) ([ return nil, err } - if !opt.DisableFailingPolicies && !premiumLicense { + // If issues are enabled, we need to remove the critical vulnerabilities count for non-premium license. + // If issues are disabled, we need to explicitly set the critical vulnerabilities count to 0 for premium license. + if !opt.DisableIssues && !premiumLicense { // Remove critical vulnerabilities count if not premium license for _, host := range hosts { host.HostIssues.CriticalVulnerabilitiesCount = nil } + } else if opt.DisableIssues && premiumLicense { + var zero uint64 + for _, host := range hosts { + host.HostIssues.CriticalVulnerabilitiesCount = &zero + } } if opt.PopulateSoftware { @@ -331,7 +338,7 @@ func (svc *Service) DeleteHosts(ctx context.Context, ids []uint, filter *map[str if opts == nil { opts = &fleet.HostListOptions{} } - opts.DisableFailingPolicies = true // don't check policies for hosts that are about to be deleted + opts.DisableIssues = true // don't check policies for hosts that are about to be deleted hostIDs, _, hosts, err := svc.hostIDsAndNamesFromFilters(ctx, *opts, lid) if err != nil { return err @@ -532,6 +539,9 @@ func (svc *Service) GetHost(ctx context.Context, id uint, opts fleet.HostDetailO if err != nil { return nil, ctxerr.Wrap(ctx, err, "get host") } + if !opts.IncludeCriticalVulnerabilitiesCount { + host.HostIssues.CriticalVulnerabilitiesCount = nil + } if !alreadyAuthd { // Authorize again with team loaded now that we have team_id @@ -1309,7 +1319,7 @@ func (svc *Service) hostIDsAndNamesFromFilters(ctx context.Context, opt fleet.Ho if lid != nil { hosts, err = svc.ds.ListHostsInLabel(ctx, filter, *lid, opt) } else { - opt.DisableFailingPolicies = true // intentionally ignore failing policies + opt.DisableIssues = true // intentionally ignore failing policies hosts, err = svc.ds.ListHosts(ctx, filter, opt) } if err != nil { diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index b1f53b9db4..99bdca4188 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -1643,12 +1643,22 @@ func (s *integrationTestSuite) TestListHosts() { assert.Nil(t, resp.Hosts[0].HostIssues.CriticalVulnerabilitiesCount) resp = listHostsResponse{} + // disable_failing_policies has been deprecated and is no longer documented; it is an alias for disable_issues s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_version_id", fmt.Sprint(fooV1ID), "disable_failing_policies", "true") require.Len(t, resp.Hosts, 1) assert.Zero(t, resp.Hosts[0].HostIssues.FailingPoliciesCount) assert.Zero(t, resp.Hosts[0].HostIssues.TotalIssuesCount) assert.Nil(t, resp.Hosts[0].HostIssues.CriticalVulnerabilitiesCount) + resp = listHostsResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_version_id", fmt.Sprint(fooV1ID), "disable_issues", "true", + ) + require.Len(t, resp.Hosts, 1) + assert.Zero(t, resp.Hosts[0].HostIssues.FailingPoliciesCount) + assert.Zero(t, resp.Hosts[0].HostIssues.TotalIssuesCount) + assert.Nil(t, resp.Hosts[0].HostIssues.CriticalVulnerabilitiesCount) + // filter by MDM criteria without any host having such information resp = listHostsResponse{} s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "mdm_id", fmt.Sprint(999)) @@ -10220,9 +10230,12 @@ func (s *integrationTestSuite) TestHostsReportWithPolicyResults() { require.Equal(t, row[issuesIdx], "1") } - // Running with disable_failing_policies=true disable the counting of failed policies for a host. + // Running with disable_issues=true (which overrides disable_failing_policies=false) disable the counting of failed policies for a host. // Thus, all "issues" values should be 0. - res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, "format", "csv", "disable_failing_policies", "true") + res = s.DoRaw( + "GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, "format", "csv", "disable_failing_policies", "false", "disable_issues", + "true", + ) rows2, err := csv.NewReader(res.Body).ReadAll() res.Body.Close() require.NoError(t, err) @@ -10291,8 +10304,10 @@ func (s *integrationTestSuite) TestHostsReportWithPolicyResults() { res.Body.Close() require.NoError(t, err) tc.checkRows(t, rows) - // Test the same with "disable_failing_policies=true" which should not change the result. - res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, append(tc.args, "format", "csv", "disable_failing_policies", "true")...) + // Test the same with "disable_issues=true" which should not change the result. + res = s.DoRaw( + "GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, append(tc.args, "format", "csv", "disable_issues", "true")..., + ) rows, err = csv.NewReader(res.Body).ReadAll() res.Body.Close() require.NoError(t, err) diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index dff33643a0..26a2402d39 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -26,6 +26,7 @@ import ( "github.com/fleetdm/fleet/v4/ee/server/calendar" "github.com/fleetdm/fleet/v4/pkg/optjson" "github.com/fleetdm/fleet/v4/server/config" + "github.com/fleetdm/fleet/v4/server/contexts/license" "github.com/fleetdm/fleet/v4/server/cron" "github.com/fleetdm/fleet/v4/server/datastore/mysql" "github.com/fleetdm/fleet/v4/server/datastore/redis/redistest" @@ -3388,9 +3389,38 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() { require.Equal(t, uint(0), summaryResp.TotalsHostsCount) require.Nil(t, summaryResp.LowDiskSpaceCount) + // Add a failing policy + ctx := context.Background() + qr, err := s.ds.NewQuery( + ctx, &fleet.Query{ + Name: "TestQueryEnterpriseTestListHosts", + Description: "Some description", + Query: "select * from osquery;", + ObserverCanRun: true, + Logging: fleet.LoggingSnapshot, + }, + ) + require.NoError(t, err) + + // add a global policy + gpParams := globalPolicyRequest{ + QueryID: &qr.ID, + Resolution: "some global resolution", + } + gpResp := globalPolicyResponse{} + s.DoJSON("POST", "/api/latest/fleet/policies", gpParams, http.StatusOK, &gpResp) + require.NotNil(t, gpResp.Policy) + + // add a failing policy execution + require.NoError( + t, s.ds.RecordPolicyQueryExecutions( + ctx, host1, + map[uint]*bool{gpResp.Policy.ID: ptr.Bool(false)}, time.Now(), false, + ), + ) + // populate software for hosts now := time.Now() - software := []fleet.Software{ {Name: "foo", Version: "0.0.1", Source: "chrome_extensions"}, } @@ -3408,7 +3438,7 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() { vulnMeta := []fleet.CVEMeta{{ CVE: "cve-123-123-123", - CVSSScore: ptr.Float64(5.4), + CVSSScore: ptr.Float64(9.8), EPSSProbability: ptr.Float64(0.5), CISAKnownExploit: ptr.Bool(true), Published: &now, @@ -3416,6 +3446,8 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() { }} require.NoError(t, s.ds.InsertCVEMeta(context.Background(), vulnMeta)) + ctx = license.NewContext(ctx, &fleet.LicenseInfo{Tier: fleet.TierPremium}) + require.NoError(t, s.ds.UpdateHostIssuesVulnerabilities(ctx)) resp = listHostsResponse{} s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "populate_software", "true") @@ -3431,6 +3463,13 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() { require.Equal(t, &vulnMeta[0].EPSSProbability, h.Software[0].Vulnerabilities[0].EPSSProbability) require.Equal(t, &vulnMeta[0].CISAKnownExploit, h.Software[0].Vulnerabilities[0].CISAKnownExploit) require.Equal(t, &s, h.Software[0].Vulnerabilities[0].Description) + assert.Equal(t, uint64(1), h.HostIssues.FailingPoliciesCount) + assert.Equal(t, uint64(1), *h.HostIssues.CriticalVulnerabilitiesCount) + assert.Equal(t, uint64(2), h.HostIssues.TotalIssuesCount) + } else { + assert.Zero(t, h.HostIssues.FailingPoliciesCount) + assert.Zero(t, *h.HostIssues.CriticalVulnerabilitiesCount) + assert.Zero(t, h.HostIssues.TotalIssuesCount) } } @@ -3440,6 +3479,44 @@ func (s *integrationEnterpriseTestSuite) TestListHosts() { for _, h := range resp.Hosts { require.Empty(t, h.Software) } + + // Test host list from labels endpoint + // First assign label to hosts + allHostsLabel, err := s.ds.GetLabelSpec(ctx, "All hosts") + require.NoError(t, err) + for _, h := range resp.Hosts { + err = s.ds.RecordLabelQueryExecutions( + context.Background(), h.Host, map[uint]*bool{allHostsLabel.ID: ptr.Bool(true)}, time.Now(), false, + ) + require.NoError(t, err) + } + + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/labels/%d/hosts", allHostsLabel.ID), nil, http.StatusOK, &resp) + assert.Len(t, resp.Hosts, 3) + for _, h := range resp.Hosts { + if h.ID == host1.ID { + assert.Equal(t, uint64(1), h.HostIssues.FailingPoliciesCount) + assert.Equal(t, uint64(1), *h.HostIssues.CriticalVulnerabilitiesCount) + assert.Equal(t, uint64(2), h.HostIssues.TotalIssuesCount) + } else { + assert.Zero(t, h.HostIssues.FailingPoliciesCount) + assert.Zero(t, *h.HostIssues.CriticalVulnerabilitiesCount) + assert.Zero(t, h.HostIssues.TotalIssuesCount) + } + } + + // Test ordering by issues + s.DoJSON( + "GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "order_key", "issues", + ) // defaults to ascending order (lowest issues to most issues) + require.Len(t, resp.Hosts, 3) + assert.Equal(t, host1.ID, resp.Hosts[2].ID) + s.DoJSON( + "GET", fmt.Sprintf("/api/latest/fleet/labels/%d/hosts", allHostsLabel.ID), nil, http.StatusOK, &resp, "order_key", "issues", + "order_direction", "desc", + ) + require.Len(t, resp.Hosts, 3) + assert.Equal(t, host1.ID, resp.Hosts[0].ID) } func (s *integrationEnterpriseTestSuite) TestHostHealth() { diff --git a/server/service/labels.go b/server/service/labels.go index eda8b2fa26..c79e8fadb9 100644 --- a/server/service/labels.go +++ b/server/service/labels.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/contexts/license" "github.com/fleetdm/fleet/v4/server/contexts/viewer" "github.com/fleetdm/fleet/v4/server/fleet" ) @@ -389,7 +390,26 @@ func (svc *Service) ListHostsInLabel(ctx context.Context, lid uint, opt fleet.Ho } filter := fleet.TeamFilter{User: vc.User, IncludeObserver: true} - return svc.ds.ListHostsInLabel(ctx, filter, lid, opt) + hosts, err := svc.ds.ListHostsInLabel(ctx, filter, lid, opt) + if err != nil { + return nil, err + } + + premiumLicense := license.IsPremium(ctx) + // If issues are enabled, we need to remove the critical vulnerabilities count for non-premium license. + // If issues are disabled, we need to explicitly set the critical vulnerabilities count to 0 for premium license. + if !opt.DisableIssues && !premiumLicense { + // Remove critical vulnerabilities count if not premium license + for _, host := range hosts { + host.HostIssues.CriticalVulnerabilitiesCount = nil + } + } else if opt.DisableIssues && premiumLicense { + var zero uint64 + for _, host := range hosts { + host.HostIssues.CriticalVulnerabilitiesCount = &zero + } + } + return hosts, nil } //////////////////////////////////////////////////////////////////////////////// diff --git a/server/service/transport.go b/server/service/transport.go index b32c939579..4e9d9df562 100644 --- a/server/service/transport.go +++ b/server/service/transport.go @@ -368,8 +368,24 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) ) } + // disable_failing_policies is a deprecated parameter and an alias for disable_issues + // disable_issues is the new parameter name, which takes precedence over disable_failing_policies disableFailingPolicies := r.URL.Query().Get("disable_failing_policies") - if disableFailingPolicies != "" { + disableIssues := r.URL.Query().Get("disable_issues") + if disableIssues != "" { + boolVal, err := strconv.ParseBool(disableIssues) + if err != nil { + return hopt, ctxerr.Wrap( + r.Context(), badRequest( + fmt.Sprintf( + "Invalid disable_issues: %s", + disableIssues, + ), + ), + ) + } + hopt.DisableIssues = boolVal + } else if disableFailingPolicies != "" { boolVal, err := strconv.ParseBool(disableFailingPolicies) if err != nil { return hopt, ctxerr.Wrap( @@ -381,7 +397,14 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) ), ) } - hopt.DisableFailingPolicies = boolVal + hopt.DisableIssues = boolVal + } + if hopt.DisableIssues && r.URL.Query().Get("order_key") == "issues" { + return hopt, ctxerr.Wrap( + r.Context(), badRequest( + "Invalid order_key (issues cannot be ordered when they are disabled)", + ), + ) } deviceMapping := r.URL.Query().Get("device_mapping") diff --git a/server/service/transport_test.go b/server/service/transport_test.go index f253adccd1..e15b30bfd1 100644 --- a/server/service/transport_test.go +++ b/server/service/transport_test.go @@ -155,7 +155,7 @@ func TestHostListOptionsFromRequest(t *testing.T) { "all params defined": { url: "/foo?order_key=foo&order_direction=asc&page=10&per_page=1&device_mapping=T&additional_info_filters" + "=filter1,filter2&status=new&team_id=2&policy_id=3&policy_response=passing&software_id=4&os_id=5" + - "&os_name=osName&os_version=osVersion&os_version_id=5&disable_failing_policies=1&macos_settings=verified" + + "&os_name=osName&os_version=osVersion&os_version_id=5&disable_failing_policies=0&disable_issues=1&macos_settings=verified" + "&macos_settings_disk_encryption=enforcing&os_settings=pending&os_settings_disk_encryption=failed" + "&bootstrap_package=installed&mdm_id=6&mdm_name=mdmName&mdm_enrollment_status=automatic" + "&munki_issue_id=7&low_disk_space=99&vulnerability=CVE-2023-42887&populate_policies=true", @@ -177,7 +177,7 @@ func TestHostListOptionsFromRequest(t *testing.T) { OSVersionIDFilter: ptr.Uint(5), OSNameFilter: ptr.String("osName"), OSVersionFilter: ptr.String("osVersion"), - DisableFailingPolicies: true, + DisableIssues: true, MacOSSettingsFilter: fleet.OSSettingsVerified, MacOSSettingsDiskEncryptionFilter: fleet.DiskEncryptionEnforcing, OSSettingsFilter: fleet.OSSettingsPending, @@ -239,6 +239,18 @@ func TestHostListOptionsFromRequest(t *testing.T) { url: "/foo?disable_failing_policies=foo", errorMessage: "Invalid disable_failing_policies", }, + "error in disable_issues": { + url: "/foo?disable_issues=foo", + errorMessage: "Invalid disable_issues", + }, + "error in issues order key when disable_issues is set": { + url: "/foo?disable_issues=true&order_key=issues", + errorMessage: "Invalid order_key", + }, + "error in issues order key when disable_failing_policies is set": { + url: "/foo?disable_failing_policies=true&order_key=issues", + errorMessage: "Invalid order_key", + }, "error in device_mapping": { url: "/foo?device_mapping=foo", errorMessage: "Invalid device_mapping",