From 6b128dd455a11d43df2cd22fcf38b933e0b5f0c7 Mon Sep 17 00:00:00 2001 From: Martin Angers Date: Wed, 6 Dec 2023 14:59:00 -0500 Subject: [PATCH] Allow filtering hosts by `software_version_id` and `software_title_id`. (#15433) --- changes/issue-15345-filter-hosts-by-software | 1 + server/datastore/mysql/hosts.go | 23 ++- server/datastore/mysql/hosts_test.go | 49 +++++- server/fleet/hosts.go | 12 ++ server/service/hosts.go | 41 ++++- server/service/integration_core_test.go | 173 +++++++++++++++++-- server/service/transport.go | 47 ++++- server/service/transport_test.go | 51 +++++- 8 files changed, 360 insertions(+), 37 deletions(-) create mode 100644 changes/issue-15345-filter-hosts-by-software diff --git a/changes/issue-15345-filter-hosts-by-software b/changes/issue-15345-filter-hosts-by-software new file mode 100644 index 0000000000..57b8f0651d --- /dev/null +++ b/changes/issue-15345-filter-hosts-by-software @@ -0,0 +1 @@ +* Added ability to filter hosts by `software_version_id` and `software_title_id` for the "list hosts", "count hosts" and "get hosts report in CSV" endpoints. diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index 461ddaa8da..5f7a2b2667 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -943,9 +943,20 @@ func (ds *Datastore) applyHostFilters( } softwareFilter := "TRUE" - if opt.SoftwareIDFilter != nil { + var softwareIDFilter *uint + if opt.SoftwareVersionIDFilter != nil { + softwareIDFilter = opt.SoftwareVersionIDFilter + } else if opt.SoftwareIDFilter != nil { + softwareIDFilter = opt.SoftwareIDFilter + } + if softwareIDFilter != nil { softwareFilter = "EXISTS (SELECT 1 FROM host_software hs WHERE hs.host_id = h.id AND hs.software_id = ?)" - params = append(params, opt.SoftwareIDFilter) + params = append(params, *softwareIDFilter) + } else if opt.SoftwareTitleIDFilter != nil { + // software (version) ID filter is mutually exclusive with software title ID + // so we're reusing the same filter to avoid adding unnecessary conditions. + softwareFilter = "EXISTS (SELECT 1 FROM host_software hs INNER JOIN software sw ON hs.software_id = sw.id WHERE hs.host_id = h.id AND sw.title_id = ?)" + params = append(params, *opt.SoftwareTitleIDFilter) } failingPoliciesJoin := "" @@ -1268,7 +1279,7 @@ func (ds *Datastore) filterHostsByOSSettingsStatus(sql string, opt fleet.HostLis WHEN (%s) THEN 'bitlocker_pending' WHEN (%s) THEN - 'bitlocker_failed' + 'bitlocker_failed' ELSE '' END`, @@ -1280,11 +1291,11 @@ func (ds *Datastore) filterHostsByOSSettingsStatus(sql string, opt fleet.HostLis } whereWindows += fmt.Sprintf(` AND ( - CASE (%s) + CASE (%s) WHEN 'profiles_failed' THEN 'failed' WHEN 'profiles_pending' THEN ( - CASE (%s) + CASE (%s) WHEN 'bitlocker_failed' THEN 'failed' ELSE @@ -1310,7 +1321,7 @@ func (ds *Datastore) filterHostsByOSSettingsStatus(sql string, opt fleet.HostLis ELSE 'verified' END) - ELSE + ELSE REPLACE((%s), 'bitlocker_', '') END) = ?`, profilesStatus, bitlockerStatus, bitlockerStatus, bitlockerStatus, bitlockerStatus) diff --git a/server/datastore/mysql/hosts_test.go b/server/datastore/mysql/hosts_test.go index 3202e65ded..34ee465729 100644 --- a/server/datastore/mysql/hosts_test.go +++ b/server/datastore/mysql/hosts_test.go @@ -2767,16 +2767,63 @@ func testHostsListBySoftware(t *testing.T, ds *Datastore) { } host1 := hosts[0] host2 := hosts[1] + host3 := hosts[2] _, err := ds.UpdateHostSoftware(context.Background(), host1.ID, software) require.NoError(t, err) _, err = ds.UpdateHostSoftware(context.Background(), host2.ID, software) require.NoError(t, err) + // host 3 only has foo v0.0.3 + _, err = ds.UpdateHostSoftware(context.Background(), host3.ID, software[1:2]) + require.NoError(t, err) + + // reconcile software, will sync software titles + err = ds.ReconcileSoftwareTitles(context.Background()) + require.NoError(t, err) + + var fooV002ID uint + ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error { + return sqlx.GetContext(context.Background(), q, &fooV002ID, + "SELECT id FROM software WHERE name = ? AND source = ? AND version = ?", "foo", "chrome_extensions", "0.0.2") + }) + + var fooTitleID uint + ExecAdhocSQL(t, ds, func(q sqlx.ExtContext) error { + return sqlx.GetContext(context.Background(), q, &fooTitleID, + "SELECT id FROM software_titles WHERE name = ? AND source = ?", "foo", "chrome_extensions") + }) require.NoError(t, ds.LoadHostSoftware(context.Background(), host1, false)) require.NoError(t, ds.LoadHostSoftware(context.Background(), host2, false)) - hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareIDFilter: &host1.Software[0].ID}, 2) + // software_id is foo v0.0.2 + hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareIDFilter: &fooV002ID}, 2) require.Len(t, hosts, 2) + got := []uint{hosts[0].ID, hosts[1].ID} + require.ElementsMatch(t, []uint{host1.ID, host2.ID}, got) + + // software_version_id is foo v0.0.2 (works exacty the same) + hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareVersionIDFilter: &fooV002ID}, 2) + require.Len(t, hosts, 2) + got = []uint{hosts[0].ID, hosts[1].ID} + require.ElementsMatch(t, []uint{host1.ID, host2.ID}, got) + + // unknown software_id + hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareIDFilter: ptr.Uint(fooV002ID + 100)}, 0) + require.Len(t, hosts, 0) + + // unknown software_version_id + hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareVersionIDFilter: ptr.Uint(fooV002ID + 100)}, 0) + require.Len(t, hosts, 0) + + // software_title_id is foo (any version) + hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareTitleIDFilter: &fooTitleID}, 3) + require.Len(t, hosts, 3) + got = []uint{hosts[0].ID, hosts[1].ID, hosts[2].ID} + require.ElementsMatch(t, []uint{host1.ID, host2.ID, host3.ID}, got) + + // unknown software_title_id + hosts = listHostsCheckCount(t, ds, filter, fleet.HostListOptions{SoftwareTitleIDFilter: ptr.Uint(fooTitleID + 100)}, 0) + require.Len(t, hosts, 0) } func testHostsListBySoftwareChangedAt(t *testing.T, ds *Datastore) { diff --git a/server/fleet/hosts.go b/server/fleet/hosts.go index b82dcb6fc0..51b6e46f23 100644 --- a/server/fleet/hosts.go +++ b/server/fleet/hosts.go @@ -126,7 +126,17 @@ type HostListOptions struct { PolicyIDFilter *uint PolicyResponseFilter *bool + // Deprecated: SoftwareIDFilter is deprecated as of Fleet 4.42. It is + // maintained for backwards compatibility. Use SoftwareVersionIDFilter + // instead. SoftwareIDFilter *uint + // SoftwareVersionIDFilter filters the hosts by the software version ID that + // they use. This identifies a specific version of a "software title". + SoftwareVersionIDFilter *uint + // SoftwareTitleIDFilter filers the hosts by the software title ID that they + // use. This identifies a "software title" independent of the specific + // version. + SoftwareTitleIDFilter *uint OSIDFilter *uint OSNameFilter *string @@ -179,6 +189,8 @@ func (h HostListOptions) Empty() bool { h.PolicyIDFilter == nil && h.PolicyResponseFilter == nil && h.SoftwareIDFilter == nil && + h.SoftwareVersionIDFilter == nil && + h.SoftwareTitleIDFilter == nil && h.OSIDFilter == nil && h.OSNameFilter == nil && h.OSVersionFilter == nil && diff --git a/server/service/hosts.go b/server/service/hosts.go index 6c5f00be92..0b828fb4b3 100644 --- a/server/service/hosts.go +++ b/server/service/hosts.go @@ -54,8 +54,17 @@ type listHostsRequest struct { } type listHostsResponse struct { - Hosts []fleet.HostResponse `json:"hosts"` - Software *fleet.Software `json:"software,omitempty"` + Hosts []fleet.HostResponse `json:"hosts"` + // Software is populated with the software version corresponding to the + // software_version_id (or software_id) filter if one is provided with the + // request (and it exists in the database). It is nil otherwise and absent of + // the JSON response payload. + Software *fleet.Software `json:"software,omitempty"` + // SoftwareTitle is populated with the title corresponding to the + // software_title_id filter if one is provided with the request (and it + // exists in the database). It is nil otherwise and absent of the JSON + // response payload. + SoftwareTitle *fleet.SoftwareTitle `json:"software_title,omitempty"` // MDMSolution is populated with the MDM solution corresponding to the mdm_id // filter if one is provided with the request (and it exists in the // database). It is nil otherwise and absent of the JSON response payload. @@ -75,9 +84,24 @@ func listHostsEndpoint(ctx context.Context, request interface{}, svc fleet.Servi req := request.(*listHostsRequest) var software *fleet.Software - if req.Opts.SoftwareIDFilter != nil { + if req.Opts.SoftwareVersionIDFilter != nil || req.Opts.SoftwareIDFilter != nil { var err error - software, err = svc.SoftwareByID(ctx, *req.Opts.SoftwareIDFilter, false) + + id := req.Opts.SoftwareVersionIDFilter + if id == nil { + id = req.Opts.SoftwareIDFilter + } + software, err = svc.SoftwareByID(ctx, *id, false) + if err != nil { + return listHostsResponse{Err: err}, nil + } + } + + var softwareTitle *fleet.SoftwareTitle + if req.Opts.SoftwareTitleIDFilter != nil { + var err error + + softwareTitle, err = svc.SoftwareTitleByID(ctx, *req.Opts.SoftwareTitleIDFilter) if err != nil { return listHostsResponse{Err: err}, nil } @@ -112,10 +136,11 @@ func listHostsEndpoint(ctx context.Context, request interface{}, svc fleet.Servi hostResponses[i] = *h } return listHostsResponse{ - Hosts: hostResponses, - Software: software, - MDMSolution: mdmSolution, - MunkiIssue: munkiIssue, + Hosts: hostResponses, + Software: software, + SoftwareTitle: softwareTitle, + MDMSolution: mdmSolution, + MunkiIssue: munkiIssue, }, nil } diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index a7cc205839..6f3958fec3 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -1363,20 +1363,119 @@ func (s *integrationTestSuite) TestListHosts() { require.Len(t, resp.Hosts, len(hosts)-2) time.Sleep(1 * time.Second) - host := hosts[2] + + // create some software for various hosts + host2 := hosts[2] software := []fleet.Software{ {Name: "foo", Version: "0.0.1", Source: "chrome_extensions"}, } - _, err := s.ds.UpdateHostSoftware(context.Background(), host.ID, software) + _, err := s.ds.UpdateHostSoftware(context.Background(), host2.ID, software) require.NoError(t, err) - require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host, false)) + require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host2, false)) + host1 := hosts[1] + software = []fleet.Software{ + {Name: "foo", Version: "0.0.2", Source: "chrome_extensions"}, + {Name: "bar", Version: "0.1.0", Source: "application"}, + } + _, err = s.ds.UpdateHostSoftware(context.Background(), host1.ID, software) + require.NoError(t, err) + require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host1, false)) + + host0 := hosts[0] + software = []fleet.Software{ + {Name: "foo", Version: "0.0.2", Source: "chrome_extensions"}, + {Name: "bar", Version: "0.2.0", Source: "not_application"}, + } + _, err = s.ds.UpdateHostSoftware(context.Background(), host0.ID, software) + require.NoError(t, err) + require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host0, false)) + + err = s.ds.SyncHostsSoftware(context.Background(), time.Now()) + require.NoError(t, err) + err = s.ds.ReconcileSoftwareTitles(context.Background()) + require.NoError(t, err) + + var fooV1ID, fooV2ID, barAppTitleID, fooTitleID uint + mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error { + err := sqlx.GetContext(context.Background(), q, &fooV1ID, + `SELECT id FROM software WHERE name = ? AND source = ? AND version = ?`, "foo", "chrome_extensions", "0.0.1") + if err != nil { + return err + } + err = sqlx.GetContext(context.Background(), q, &fooV2ID, + `SELECT id FROM software WHERE name = ? AND source = ? AND version = ?`, "foo", "chrome_extensions", "0.0.2") + if err != nil { + return err + } + err = sqlx.GetContext(context.Background(), q, &barAppTitleID, + `SELECT id FROM software_titles WHERE name = ? AND source = ?`, "bar", "application") + if err != nil { + return err + } + err = sqlx.GetContext(context.Background(), q, &fooTitleID, + `SELECT id FROM software_titles WHERE name = ? AND source = ?`, "foo", "chrome_extensions") + if err != nil { + return err + } + return nil + }) + + // foo v0.0.1 is only installed on host2 resp = listHostsResponse{} - s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_id", fmt.Sprint(host.Software[0].ID)) + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_id", fmt.Sprint(fooV1ID)) require.Len(t, resp.Hosts, 1) - assert.Equal(t, host.ID, resp.Hosts[0].ID) + assert.Equal(t, host2.ID, resp.Hosts[0].ID) assert.Equal(t, "foo", resp.Software.Name) assert.Greater(t, resp.Hosts[0].SoftwareUpdatedAt, resp.Hosts[0].CreatedAt) + assert.Nil(t, resp.SoftwareTitle) + + var countResp countHostsResponse + s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusOK, &countResp, "software_id", fmt.Sprint(fooV1ID)) + require.Equal(t, 1, countResp.Count) + + // foo v0.0.2 is installed on hosts 0 and 1 + resp = listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_version_id", fmt.Sprint(fooV2ID)) + require.Len(t, resp.Hosts, 2) + require.ElementsMatch(t, []uint{host0.ID, host1.ID}, []uint{resp.Hosts[0].ID, resp.Hosts[1].ID}) + + s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusOK, &countResp, "software_version_id", fmt.Sprint(fooV2ID)) + require.Equal(t, 2, countResp.Count) + + // bar/application title is only on host1 + resp = listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_title_id", fmt.Sprint(barAppTitleID)) + require.Len(t, resp.Hosts, 1) + require.ElementsMatch(t, []uint{host1.ID}, []uint{resp.Hosts[0].ID}) + assert.Equal(t, "bar", resp.SoftwareTitle.Name) + assert.Equal(t, "application", resp.SoftwareTitle.Source) + assert.Equal(t, uint(1), resp.SoftwareTitle.HostsCount) + require.Len(t, resp.SoftwareTitle.Versions, 1) + assert.Equal(t, "0.1.0", resp.SoftwareTitle.Versions[0].Version) + assert.Nil(t, resp.Software) + + s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusOK, &countResp, "software_title_id", fmt.Sprint(barAppTitleID)) + require.Equal(t, 1, countResp.Count) + + // foo title is on all 3 hosts + resp = listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_title_id", fmt.Sprint(fooTitleID)) + require.Len(t, resp.Hosts, 3) + require.ElementsMatch(t, []uint{host0.ID, host1.ID, host2.ID}, []uint{resp.Hosts[0].ID, resp.Hosts[1].ID, resp.Hosts[2].ID}) + + s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusOK, &countResp, "software_title_id", fmt.Sprint(fooTitleID)) + require.Equal(t, 3, countResp.Count) + + // verify invalid combinations of software filters + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusBadRequest, &resp, "software_title_id", fmt.Sprint(fooTitleID), "software_id", fmt.Sprint(fooV1ID)) + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusBadRequest, &resp, "software_title_id", fmt.Sprint(fooTitleID), "software_version_id", fmt.Sprint(fooV1ID)) + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusBadRequest, &resp, "software_id", fmt.Sprint(fooV1ID), "software_version_id", fmt.Sprint(fooV1ID)) + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusBadRequest, &resp, "software_id", fmt.Sprint(fooV1ID), "software_version_id", fmt.Sprint(fooV1ID), "software_title_id", fmt.Sprint(fooTitleID)) + s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusBadRequest, &countResp, "software_title_id", fmt.Sprint(fooTitleID), "software_id", fmt.Sprint(fooV1ID)) + s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusBadRequest, &countResp, "software_title_id", fmt.Sprint(fooTitleID), "software_version_id", fmt.Sprint(fooV1ID)) + s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusBadRequest, &countResp, "software_id", fmt.Sprint(fooV1ID), "software_version_id", fmt.Sprint(fooV1ID)) + s.DoJSON("GET", "/api/latest/fleet/hosts/count", nil, http.StatusBadRequest, &countResp, "software_id", fmt.Sprint(fooV1ID), "software_version_id", fmt.Sprint(fooV1ID), "software_title_id", fmt.Sprint(fooTitleID)) user1 := test.NewUser(t, s.ds, "Alice", "alice@example.com", true) q := test.NewQuery(t, s.ds, nil, "query1", "select 1", 0, true) @@ -1386,16 +1485,16 @@ func (s *integrationTestSuite) TestListHosts() { }) require.NoError(t, err) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false)) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), host2, map[uint]*bool{p.ID: ptr.Bool(false)}, time.Now(), false)) resp = listHostsResponse{} - s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_id", fmt.Sprint(host.Software[0].ID)) + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_id", fmt.Sprint(fooV1ID)) require.Len(t, resp.Hosts, 1) assert.Equal(t, 1, resp.Hosts[0].HostIssues.FailingPoliciesCount) assert.Equal(t, 1, resp.Hosts[0].HostIssues.TotalIssuesCount) resp = listHostsResponse{} - s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_id", fmt.Sprint(host.Software[0].ID), "disable_failing_policies", "true") + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusOK, &resp, "software_version_id", fmt.Sprint(fooV1ID), "disable_failing_policies", "true") require.Len(t, resp.Hosts, 1) assert.Equal(t, 0, resp.Hosts[0].HostIssues.FailingPoliciesCount) assert.Equal(t, 0, resp.Hosts[0].HostIssues.TotalIssuesCount) @@ -1422,7 +1521,7 @@ func (s *integrationTestSuite) TestListHosts() { assert.Nil(t, resp.MunkiIssue) // set MDM information on a host - require.NoError(t, s.ds.SetOrUpdateMDMData(context.Background(), host.ID, false, true, "https://simplemdm.com", false, fleet.WellKnownMDMSimpleMDM)) + require.NoError(t, s.ds.SetOrUpdateMDMData(context.Background(), host2.ID, false, true, "https://simplemdm.com", false, fleet.WellKnownMDMSimpleMDM)) var mdmID uint mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error { return sqlx.GetContext(context.Background(), q, &mdmID, @@ -1500,13 +1599,17 @@ func (s *integrationTestSuite) TestListHosts() { // Filter by inexistent software. resp = listHostsResponse{} s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusNotFound, &resp, "software_id", fmt.Sprint(9999)) + resp = listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusNotFound, &resp, "software_version_id", fmt.Sprint(9999)) + resp = listHostsResponse{} + s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusNotFound, &resp, "software_title_id", fmt.Sprint(9999)) // Filter by non-existent team. resp = listHostsResponse{} s.DoJSON("GET", "/api/latest/fleet/hosts", nil, http.StatusBadRequest, &resp, "team_id", fmt.Sprint(9999)) // set munki information on a host - require.NoError(t, s.ds.SetOrUpdateMunkiInfo(context.Background(), host.ID, "1.2.3", []string{"err"}, []string{"warn"})) + require.NoError(t, s.ds.SetOrUpdateMunkiInfo(context.Background(), host2.ID, "1.2.3", []string{"err"}, []string{"warn"})) var errMunkiID uint mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error { return sqlx.GetContext(context.Background(), q, &errMunkiID, @@ -1537,7 +1640,7 @@ func (s *integrationTestSuite) TestListHosts() { // set operating system information on a host testOS := fleet.OperatingSystem{Name: "fooOS", Version: "4.2", Arch: "64bit", KernelVersion: "13.37", Platform: "bar"} - require.NoError(t, s.ds.UpdateHostOperatingSystem(context.Background(), host.ID, testOS)) + require.NoError(t, s.ds.UpdateHostOperatingSystem(context.Background(), host2.ID, testOS)) var osID uint mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error { return sqlx.GetContext(context.Background(), q, &osID, @@ -6567,6 +6670,32 @@ func (s *integrationTestSuite) TestHostsReportDownload() { require.NoError(t, s.ds.SetOrUpdateHostDisksSpace(ctx, hosts[0].ID, 1.0, 2.0)) require.NoError(t, s.ds.SetOrUpdateHostDisksSpace(ctx, hosts[1].ID, 3.0, 4.0)) + // create software for host [0] + software := []fleet.Software{ + {Name: "foo", Version: "0.0.1", Source: "chrome_extensions"}, + } + _, err = s.ds.UpdateHostSoftware(ctx, hosts[0].ID, software) + require.NoError(t, err) + require.NoError(t, s.ds.LoadHostSoftware(ctx, hosts[0], false)) + + err = s.ds.ReconcileSoftwareTitles(ctx) + require.NoError(t, err) + + var fooV1ID, fooTitleID uint + mysql.ExecAdhocSQL(t, s.ds, func(q sqlx.ExtContext) error { + err := sqlx.GetContext(context.Background(), q, &fooV1ID, + `SELECT id FROM software WHERE name = ? AND source = ? AND version = ?`, "foo", "chrome_extensions", "0.0.1") + if err != nil { + return err + } + err = sqlx.GetContext(context.Background(), q, &fooTitleID, + `SELECT id FROM software_titles WHERE name = ? AND source = ?`, "foo", "chrome_extensions") + if err != nil { + return err + } + return nil + }) + res := s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusUnsupportedMediaType, "format", "gzip") var errs validationErrResp require.NoError(t, json.NewDecoder(res.Body).Decode(&errs)) @@ -6654,6 +6783,22 @@ func (s *integrationTestSuite) TestHostsReportDownload() { require.Len(t, rows, 2) // headers + member host require.Contains(t, rows[1], hosts[2].Hostname) + // with a software version id + res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, "format", "csv", "columns", "hostname", "software_version_id", fmt.Sprint(fooV1ID)) + rows, err = csv.NewReader(res.Body).ReadAll() + res.Body.Close() + require.NoError(t, err) + require.Len(t, rows, 2) // headers + member host + require.Contains(t, rows[1], hosts[0].Hostname) + + // with a software title id + res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusOK, "format", "csv", "columns", "hostname", "software_title_id", fmt.Sprint(fooTitleID)) + rows, err = csv.NewReader(res.Body).ReadAll() + res.Body.Close() + require.NoError(t, err) + require.Len(t, rows, 2) // headers + member host + require.Contains(t, rows[1], hosts[0].Hostname) + // valid format but an invalid column is provided res = s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusBadRequest, "format", "csv", "columns", "memory,hostname,status,nosuchcolumn") require.NoError(t, json.NewDecoder(res.Body).Decode(&errs)) @@ -6675,6 +6820,12 @@ func (s *integrationTestSuite) TestHostsReportDownload() { require.Equal(t, []string{"0", "TestIntegrations/TestHostsReportDownloadfoo.local1"}, rows[2][:2]) require.Len(t, rows[3], 3) require.Equal(t, []string{"0", "TestIntegrations/TestHostsReportDownloadfoo.local0"}, rows[3][:2]) + + // invalid combinations of software filters + s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusBadRequest, "software_title_id", "123", "software_id", "456") + s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusBadRequest, "software_title_id", "123", "software_version_id", "456") + s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusBadRequest, "software_id", "123", "software_version_id", "456") + s.DoRaw("GET", "/api/latest/fleet/hosts/report", nil, http.StatusBadRequest, "software_id", "123", "software_version_id", "456", "software_title_id", "789") } func (s *integrationTestSuite) TestSSODisabled() { diff --git a/server/service/transport.go b/server/service/transport.go index 53088950c8..37b8989497 100644 --- a/server/service/transport.go +++ b/server/service/transport.go @@ -217,7 +217,7 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) teamID := r.URL.Query().Get("team_id") if teamID != "" { - id, err := strconv.Atoi(teamID) + id, err := strconv.ParseUint(teamID, 10, 32) if err != nil { return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid team_id: %s", teamID))) } @@ -227,7 +227,7 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) policyID := r.URL.Query().Get("policy_id") if policyID != "" { - id, err := strconv.Atoi(policyID) + id, err := strconv.ParseUint(policyID, 10, 32) if err != nil { return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid policy_id: %s", policyID))) } @@ -266,7 +266,7 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) softwareID := r.URL.Query().Get("software_id") if softwareID != "" { - id, err := strconv.Atoi(softwareID) + id, err := strconv.ParseUint(softwareID, 10, 64) if err != nil { return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid software_id: %s", softwareID))) } @@ -274,9 +274,29 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) hopt.SoftwareIDFilter = &sid } + softwareVersionID := r.URL.Query().Get("software_version_id") + if softwareVersionID != "" { + id, err := strconv.ParseUint(softwareVersionID, 10, 64) + if err != nil { + return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid software_version_id: %s", softwareVersionID))) + } + sid := uint(id) + hopt.SoftwareVersionIDFilter = &sid + } + + softwareTitleID := r.URL.Query().Get("software_title_id") + if softwareTitleID != "" { + id, err := strconv.ParseUint(softwareTitleID, 10, 32) + if err != nil { + return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid software_title_id: %s", softwareTitleID))) + } + sid := uint(id) + hopt.SoftwareTitleIDFilter = &sid + } + osID := r.URL.Query().Get("os_id") if osID != "" { - id, err := strconv.Atoi(osID) + id, err := strconv.ParseUint(osID, 10, 32) if err != nil { return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid os_id: %s", osID))) } @@ -336,7 +356,7 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) mdmID := r.URL.Query().Get("mdm_id") if mdmID != "" { - id, err := strconv.Atoi(mdmID) + id, err := strconv.ParseUint(mdmID, 10, 32) if err != nil { return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid mdm_id: %s", mdmID))) } @@ -438,7 +458,7 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) munkiIssueID := r.URL.Query().Get("munki_issue_id") if munkiIssueID != "" { - id, err := strconv.Atoi(munkiIssueID) + id, err := strconv.ParseUint(munkiIssueID, 10, 32) if err != nil { return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid munki_issue_id: %s", munkiIssueID))) } @@ -464,6 +484,21 @@ func hostListOptionsFromRequest(r *http.Request) (fleet.HostListOptions, error) hopt.LowDiskSpaceFilter = &v } + // cannot combine software_id, software_version_id, and software_title_id + var softwareErrorLabel []string + if hopt.SoftwareIDFilter != nil { + softwareErrorLabel = append(softwareErrorLabel, "software_id") + } + if hopt.SoftwareVersionIDFilter != nil { + softwareErrorLabel = append(softwareErrorLabel, "software_version_id") + } + if hopt.SoftwareTitleIDFilter != nil { + softwareErrorLabel = append(softwareErrorLabel, "software_title_id") + } + if len(softwareErrorLabel) > 1 { + return hopt, ctxerr.Wrap(r.Context(), badRequest(fmt.Sprintf("Invalid parameters. The combination of %s is not allowed.", strings.Join(softwareErrorLabel, " and ")))) + } + return hopt, nil } diff --git a/server/service/transport_test.go b/server/service/transport_test.go index cc4504e801..c167573d83 100644 --- a/server/service/transport_test.go +++ b/server/service/transport_test.go @@ -1,12 +1,14 @@ package service import ( - "github.com/fleetdm/fleet/v4/server/ptr" + "fmt" "net/http" "net/url" "strings" "testing" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/fleetdm/fleet/v4/server/fleet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -291,6 +293,48 @@ func TestHostListOptionsFromRequest(t *testing.T) { url: "/foo?os_name=foo", errorMessage: "Invalid os_version", }, + "negative software_id": { + url: "/foo?software_id=-10", + errorMessage: "Invalid software_id", + }, + "negative software_version_id": { + url: "/foo?software_version_id=-10", + errorMessage: "Invalid software_version_id", + }, + "negative software_title_id": { + url: "/foo?software_title_id=-10", + errorMessage: "Invalid software_title_id", + }, + "software_title_id too big": { + url: "/foo?software_title_id=" + fmt.Sprint(1<<33), + errorMessage: "Invalid software_title_id", + }, + "software_version_id can be > 32bits": { + url: "/foo?software_version_id=" + fmt.Sprint(1<<33), + hostListOptions: fleet.HostListOptions{ + SoftwareVersionIDFilter: ptr.Uint(1 << 33), + }, + }, + "good software_version_id": { + url: "/foo?software_version_id=1", + hostListOptions: fleet.HostListOptions{ + SoftwareVersionIDFilter: ptr.Uint(1), + }, + }, + "good software_title_id": { + url: "/foo?software_title_id=1", + hostListOptions: fleet.HostListOptions{ + SoftwareTitleIDFilter: ptr.Uint(1), + }, + }, + "invalid combination software_title_id and software_version_id": { + url: "/foo?software_title_id=1&software_version_id=2", + errorMessage: "The combination of software_version_id and software_title_id is not allowed", + }, + "invalid combination software_id and software_version_id": { + url: "/foo?software_id=1&software_version_id=2", + errorMessage: "The combination of software_id and software_version_id is not allowed", + }, } for name, tt := range hostListOptionsTests { @@ -304,10 +348,7 @@ func TestHostListOptionsFromRequest(t *testing.T) { assert.NotNil(t, err) var be *fleet.BadRequestError require.ErrorAs(t, err, &be) - assert.True( - t, strings.Contains(err.Error(), tt.errorMessage), - "error message '%v' should contain '%v'", err.Error(), tt.errorMessage, - ) + require.Contains(t, err.Error(), tt.errorMessage) return } assert.Nil(t, err)