From aada28c1c16cb490bd13a1bf7900716e53894d77 Mon Sep 17 00:00:00 2001 From: Roberto Dip Date: Wed, 6 Dec 2023 15:28:31 -0300 Subject: [PATCH] Add list/detail endpoints for software titles (#15464) related to #15228 --- server/datastore/mysql/software_titles.go | 199 ++++++++ server/fleet/datastore.go | 6 + server/fleet/service.go | 6 + server/fleet/software.go | 46 ++ server/mock/datastore_mock.go | 24 + server/service/handler.go | 3 + server/service/integration_core_test.go | 15 + server/service/integration_enterprise_test.go | 427 ++++++++++++++++++ server/service/software_titles.go | 126 ++++++ server/service/software_titles_test.go | 161 +++++++ 10 files changed, 1013 insertions(+) create mode 100644 server/datastore/mysql/software_titles.go create mode 100644 server/service/software_titles.go create mode 100644 server/service/software_titles_test.go diff --git a/server/datastore/mysql/software_titles.go b/server/datastore/mysql/software_titles.go new file mode 100644 index 0000000000..18ac8aac34 --- /dev/null +++ b/server/datastore/mysql/software_titles.go @@ -0,0 +1,199 @@ +package mysql + +import ( + "context" + "database/sql" + "fmt" + + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/jmoiron/sqlx" +) + +func (ds *Datastore) SoftwareTitleByID(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) { + const selectSoftwareTitleStmt = ` +SELECT + st.id, + st.name, + st.source, + COUNT(DISTINCT hs.host_id) AS hosts_count, + COUNT(DISTINCT s.id) AS versions_count +FROM software_titles st +JOIN software s ON s.title_id = st.id +JOIN host_software hs ON hs.software_id = s.id +WHERE st.id = ? +GROUP BY st.id + ` + var title fleet.SoftwareTitle + if err := sqlx.GetContext(ctx, ds.reader(ctx), &title, selectSoftwareTitleStmt, id); err != nil { + if err == sql.ErrNoRows { + return nil, notFound("SoftwareTitle").WithID(id) + } + return nil, ctxerr.Wrap(ctx, err, "get software title") + } + + selectSoftwareVersionsStmt, args, err := selectSoftwareVersionsSQL([]uint{id}, 0, true) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "building versions statement") + } + + if err := sqlx.SelectContext(ctx, ds.reader(ctx), &title.Versions, selectSoftwareVersionsStmt, args...); err != nil { + return nil, ctxerr.Wrap(ctx, err, "get software title version") + } + + return &title, nil +} + +func (ds *Datastore) ListSoftwareTitles( + ctx context.Context, + opt fleet.SoftwareTitleListOptions, +) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) { + dbReader := ds.reader(ctx) + getTitlesStmt, args := selectSoftwareTitlesSQL(opt) + // build the count statement before adding the pagination constraints to `getTitlesStmt` + getTitlesCountStmt := fmt.Sprintf(`SELECT COUNT(DISTINCT s.id) FROM (%s) AS s`, getTitlesStmt) + + // grab titles that match the list options + var titles []fleet.SoftwareTitle + getTitlesStmt, args = appendListOptionsWithCursorToSQL(getTitlesStmt, args, &opt.ListOptions) + if err := sqlx.SelectContext(ctx, dbReader, &titles, getTitlesStmt, args...); err != nil { + return nil, 0, nil, ctxerr.Wrap(ctx, err, "select software titles") + } + + // perform a second query to grab the counts + var counts int + if err := sqlx.GetContext(ctx, dbReader, &counts, getTitlesCountStmt, args...); err != nil { + return nil, 0, nil, ctxerr.Wrap(ctx, err, "get software titles count") + } + + // if we don't have any matching titles, there's no point trying to + // find matching versions. Early return + if len(titles) == 0 { + return titles, counts, &fleet.PaginationMetadata{}, nil + } + + // grab all the IDs to find matching versions below + titleIDs := make([]uint, len(titles)) + // build an index to quickly access a title by it's ID + titleIndex := make(map[uint]int, len(titles)) + for i, title := range titles { + titleIDs[i] = title.ID + titleIndex[title.ID] = i + } + + // we grab matching versions separately and build the desired object in + // the application logic. This is because we need to support MySQL 5.7 + // and there's no good way to do an aggregation that builds a structure + // (like a JSON) object for nested arrays. + var teamID uint + if opt.TeamID != nil { + teamID = *opt.TeamID + } + getVersionsStmt, args, err := selectSoftwareVersionsSQL(titleIDs, teamID, false) + if err != nil { + return nil, 0, nil, ctxerr.Wrap(ctx, err, "build get versions stmt") + } + var versions []fleet.SoftwareVersion + if err := sqlx.SelectContext(ctx, dbReader, &versions, getVersionsStmt, args...); err != nil { + return nil, 0, nil, ctxerr.Wrap(ctx, err, "get software versions") + } + + // append matching versions to titles + for _, version := range versions { + if i, ok := titleIndex[version.TitleID]; ok { + titles[i].Versions = append(titles[i].Versions, version) + } + } + + var metaData *fleet.PaginationMetadata + if opt.ListOptions.IncludeMetadata { + metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.ListOptions.Page > 0} + if len(titles) > int(opt.ListOptions.PerPage) { + metaData.HasNextResults = true + titles = titles[:len(titles)-1] + } + } + + return titles, counts, metaData, nil +} + +func selectSoftwareTitlesSQL(opt fleet.SoftwareTitleListOptions) (string, []any) { + stmt := ` +SELECT + st.id, + st.name, + st.source, + COUNT(DISTINCT hs.host_id) AS hosts_count, + COUNT(DISTINCT s.id) AS versions_count +FROM software_titles st +JOIN software s ON s.title_id = st.id +JOIN host_software hs ON hs.software_id = s.id +-- placeholder for changing the JOIN type to filter vulnerable software +%s JOIN software_cve scve ON s.id = scve.software_id +-- placeholder for potential JOIN on hosts +%s +-- placeholder for WHERE clause +WHERE %s +GROUP BY st.id` + + cveJoinType := "LEFT" + if opt.VulnerableOnly { + cveJoinType = "INNER" + } + + var args []any + hostsJoin := "" + whereClause := "TRUE" + if opt.TeamID != nil { + hostsJoin = "JOIN hosts h ON h.id = hs.host_id" + whereClause = "h.team_id = ?" + args = append(args, opt.TeamID) + } + + if match := opt.ListOptions.MatchQuery; match != "" { + whereClause += " AND (st.name LIKE ? OR scve.cve LIKE ?)" + match = likePattern(match) + args = append(args, match, match) + } + + stmt = fmt.Sprintf(stmt, cveJoinType, hostsJoin, whereClause) + return stmt, args +} + +func selectSoftwareVersionsSQL(titleIDs []uint, teamID uint, withCounts bool) (string, []any, error) { + selectVersionsStmt := ` +SELECT + st.id as title_id, + s.id, s.version, + %s -- placeholder for optional host_counts + CONCAT('[', GROUP_CONCAT(JSON_QUOTE(scve.cve) SEPARATOR ','), ']') as vulnerabilities +FROM software_titles st +JOIN software s ON s.title_id = st.id +LEFT JOIN host_software hs ON hs.software_id = s.id +LEFT JOIN software_cve scve ON s.id = scve.software_id +%s -- placeholder for optional JOIN ON host_counts +WHERE st.id IN (?) +GROUP BY s.id` + + var args []any + extraSelect := "" + extraJoin := "" + if withCounts { + args = append(args, teamID) + extraSelect = "MAX(shc.hosts_count) AS hosts_count," + extraJoin = ` + JOIN software_host_counts shc + ON shc.software_id = s.id + AND shc.hosts_count > 0 + AND shc.team_id = ? + ` + } + + args = append(args, titleIDs) + selectVersionsStmt = fmt.Sprintf(selectVersionsStmt, extraSelect, extraJoin) + selectVersionsStmt, args, err := sqlx.In(selectVersionsStmt, args...) + if err != nil { + return "", nil, fmt.Errorf("bulding sqlx.In query: %w", err) + } + return selectVersionsStmt, args, nil +} diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 35c8528bb2..aae314fe7b 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -429,6 +429,12 @@ type Datastore interface { // are being deleted from the global configuration. DeleteIntegrationsFromTeams(ctx context.Context, deletedIntgs Integrations) error + /////////////////////////////////////////////////////////////////////////////// + // Software Titles + + ListSoftwareTitles(ctx context.Context, opt SoftwareTitleListOptions) ([]SoftwareTitle, int, *PaginationMetadata, error) + SoftwareTitleByID(ctx context.Context, id uint) (*SoftwareTitle, error) + /////////////////////////////////////////////////////////////////////////////// // SoftwareStore diff --git a/server/fleet/service.go b/server/fleet/service.go index 93e988a895..4f670c0753 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -567,6 +567,12 @@ type Service interface { SoftwareByID(ctx context.Context, id uint, includeCVEScores bool) (*Software, error) CountSoftware(ctx context.Context, opt SoftwareListOptions) (int, error) + // ///////////////////////////////////////////////////////////////////////////// + // Software Titles + + ListSoftwareTitles(ctx context.Context, opt SoftwareTitleListOptions) ([]SoftwareTitle, int, *PaginationMetadata, error) + SoftwareTitleByID(ctx context.Context, id uint) (*SoftwareTitle, error) + // ///////////////////////////////////////////////////////////////////////////// // Team Policies diff --git a/server/fleet/software.go b/server/fleet/software.go index 701a958798..7730dab201 100644 --- a/server/fleet/software.go +++ b/server/fleet/software.go @@ -1,6 +1,7 @@ package fleet import ( + "encoding/json" "errors" "fmt" "strconv" @@ -102,12 +103,57 @@ func (s Software) ToUniqueStr() string { return strings.Join(ss, SoftwareFieldSeparator) } +type SliceString []string + +func (c *SliceString) Scan(v interface{}) error { + switch tv := v.(type) { + case []byte: + return json.Unmarshal(tv, &c) + } + return errors.New("unsupported type") +} + +// SoftwareVersion is an abstraction over the `software` table to support the +// software titles APIs +type SoftwareVersion struct { + ID uint `db:"id" json:"id"` + // Version is the version string we grab for this specific software. + Version string `db:"version" json:"version"` + // Vulnerabilities is the list of CVE names for vulnerabilities found for this version. + Vulnerabilities *SliceString `db:"vulnerabilities" json:"vulnerabilities,omitempty"` + // HostsCount is the number of hosts that use this software version. + HostsCount *uint `db:"hosts_count" json:"hosts_count,omitempty"` + + // TitleID is used only as an auxiliary field and it's not part of the + // JSON response. + TitleID uint `db:"title_id" json:"-"` +} + +// SoftwareTitle represents a title backed by the `software_titles` table. type SoftwareTitle struct { ID uint `json:"id" db:"id"` // Name is the name reported by osquery. Name string `json:"name" db:"name"` // Source is the source reported by osquery. Source string `json:"source" db:"source"` + // HostsCount is the number of hosts that use this software title. + HostsCount uint `json:"hosts_count" db:"hosts_count"` + // VesionsCount is the number of versions that have the same title. + VersionsCount uint `json:"versions_count" db:"versions_count"` + // Versions countains information about the versions that use this title. + Versions []SoftwareVersion `json:"versions" db:"-"` + // CountsUpdatedAt is the timestamp when the hosts count + // was last updated for that software, filled only if hosts + // count is requested. + CountsUpdatedAt time.Time `json:"-" db:"counts_updated_at"` +} + +type SoftwareTitleListOptions struct { + // ListOptions cannot be embedded in order to unmarshall with validation. + ListOptions ListOptions `url:"list_options"` + + TeamID *uint `query:"team_id,optional"` + VulnerableOnly bool `query:"vulnerable,optional"` } // AuthzSoftwareInventory is used for access controls on software inventory. diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index 01d0ce92fe..9c6cb8131d 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -324,6 +324,10 @@ type TeamEnrollSecretsFunc func(ctx context.Context, teamID uint) ([]*fleet.Enro type DeleteIntegrationsFromTeamsFunc func(ctx context.Context, deletedIntgs fleet.Integrations) error +type ListSoftwareTitlesFunc func(ctx context.Context, opt fleet.SoftwareTitleListOptions) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) + +type SoftwareTitleByIDFunc func(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) + type ListSoftwareForVulnDetectionFunc func(ctx context.Context, hostID uint) ([]fleet.Software, error) type ListSoftwareVulnerabilitiesByHostIDsSourceFunc func(ctx context.Context, hostIDs []uint, source fleet.VulnerabilitySource) (map[uint][]fleet.SoftwareVulnerability, error) @@ -1224,6 +1228,12 @@ type DataStore struct { DeleteIntegrationsFromTeamsFunc DeleteIntegrationsFromTeamsFunc DeleteIntegrationsFromTeamsFuncInvoked bool + ListSoftwareTitlesFunc ListSoftwareTitlesFunc + ListSoftwareTitlesFuncInvoked bool + + SoftwareTitleByIDFunc SoftwareTitleByIDFunc + SoftwareTitleByIDFuncInvoked bool + ListSoftwareForVulnDetectionFunc ListSoftwareForVulnDetectionFunc ListSoftwareForVulnDetectionFuncInvoked bool @@ -2958,6 +2968,20 @@ func (s *DataStore) DeleteIntegrationsFromTeams(ctx context.Context, deletedIntg return s.DeleteIntegrationsFromTeamsFunc(ctx, deletedIntgs) } +func (s *DataStore) ListSoftwareTitles(ctx context.Context, opt fleet.SoftwareTitleListOptions) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) { + s.mu.Lock() + s.ListSoftwareTitlesFuncInvoked = true + s.mu.Unlock() + return s.ListSoftwareTitlesFunc(ctx, opt) +} + +func (s *DataStore) SoftwareTitleByID(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) { + s.mu.Lock() + s.SoftwareTitleByIDFuncInvoked = true + s.mu.Unlock() + return s.SoftwareTitleByIDFunc(ctx, id) +} + func (s *DataStore) ListSoftwareForVulnDetection(ctx context.Context, hostID uint) ([]fleet.Software, error) { s.mu.Lock() s.ListSoftwareForVulnDetectionFuncInvoked = true diff --git a/server/service/handler.go b/server/service/handler.go index 4274c784d8..282cb11c8a 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -368,6 +368,9 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC // DEPRECATED: software version counts are now included directly in the software version list ue.GET("/api/_version_/fleet/software/count", countSoftwareEndpoint, countSoftwareRequest{}) + ue.GET("/api/_version_/fleet/software/titles", listSoftwareTitlesEndpoint, listSoftwareTitlesRequest{}) + ue.GET("/api/_version_/fleet/software/titles/{id:[0-9]+}", getSoftwareTitleEndpoint, getSoftwareTitleRequest{}) + ue.GET("/api/_version_/fleet/host_summary", getHostSummaryEndpoint, getHostSummaryRequest{}) ue.GET("/api/_version_/fleet/hosts", listHostsEndpoint, listHostsRequest{}) ue.POST("/api/_version_/fleet/hosts/delete", deleteHostsEndpoint, deleteHostsRequest{}) diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index fd51e11696..e2af124c01 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -5005,6 +5005,21 @@ func (s *integrationTestSuite) TestPremiumEndpointsWithoutLicense() { // batch set scripts s.Do("POST", "/api/v1/fleet/scripts/batch", batchSetScriptsRequest{Scripts: nil}, http.StatusPaymentRequired) + + // software titles + // a normal request works fine + var resp listSoftwareTitlesResponse + s.DoJSON("GET", "/api/latest/fleet/software/titles", listSoftwareTitlesRequest{}, http.StatusOK, &resp) + require.Equal(t, 0, resp.Count) + require.Nil(t, resp.SoftwareTitles) + + // a request with a team_id parameter returns a license error + resp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, http.StatusPaymentRequired, &resp, + "team_id", "1", + ) } // TestGlobalPoliciesBrowsing tests that team users can browse (read) global policies (see #3722). diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index 61d60cbab9..963fb954b8 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -3665,6 +3665,8 @@ func (s *integrationEnterpriseTestSuite) TestGitOpsUserActions() { s.DoJSON("GET", "/api/latest/fleet/software/versions", listSoftwareRequest{}, http.StatusForbidden, &listSoftwareVersionsResponse{}) s.DoJSON("GET", "/api/latest/fleet/software", listSoftwareRequest{}, http.StatusForbidden, &listSoftwareResponse{}) s.DoJSON("GET", "/api/latest/fleet/software/count", countSoftwareRequest{}, http.StatusForbidden, &countSoftwareResponse{}) + s.DoJSON("GET", "/api/latest/fleet/software/titles", listSoftwareTitlesRequest{}, http.StatusForbidden, &listSoftwareTitlesResponse{}) + s.DoJSON("GET", "/api/latest/fleet/software/titles/1", getSoftwareTitleRequest{}, http.StatusForbidden, &getSoftwareTitleResponse{}) // Attempt to list a software, should fail. s.DoJSON("GET", "/api/latest/fleet/software/1", getSoftwareRequest{}, http.StatusForbidden, &getSoftwareResponse{}) @@ -5689,6 +5691,431 @@ func (s *integrationEnterpriseTestSuite) TestTeamConfigDetailQueriesOverrides() require.Contains(t, dqResp.Queries, fmt.Sprintf("fleet_distributed_query_%s", t.Name())) } +func (s *integrationEnterpriseTestSuite) TestAllSoftwareTitles() { + ctx := context.Background() + t := s.T() + + softwareTitlesMatch := func(want, got []fleet.SoftwareTitle) { + // compare only the fields we care about + for i := range got { + require.NotZero(t, got[i].ID) + got[i].ID = 0 + + for j := range got[i].Versions { + require.NotZero(t, got[i].Versions[j].ID) + got[i].Versions[j].ID = 0 + } + } + + // sort and use EqualValues instead of ElementsMatch in order + // to do a deep comparison of nested structures + sort.Slice(got, func(i, j int) bool { + return got[i].Name < got[j].Name + }) + sort.Slice(want, func(i, j int) bool { + return want[i].Name < want[j].Name + }) + + require.EqualValues(t, want, got) + } + + host, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now().Add(-1 * time.Minute), + OsqueryHostID: ptr.String(t.Name()), + NodeKey: ptr.String(t.Name()), + UUID: uuid.New().String(), + Hostname: fmt.Sprintf("%sfoo.local", t.Name()), + Platform: "darwin", + }) + require.NoError(t, err) + + tmHost, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + PolicyUpdatedAt: time.Now(), + SeenTime: time.Now().Add(-1 * time.Minute), + OsqueryHostID: ptr.String(t.Name() + "tm"), + NodeKey: ptr.String(t.Name() + "tm"), + UUID: uuid.New().String(), + Hostname: fmt.Sprintf("%sfoo.local", t.Name()+"tm"), + Platform: "linux", + }) + require.NoError(t, err) + + // create a couple of teams and add tmHost to one + team1, err := s.ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team1"}) + require.NoError(t, err) + _, err = s.ds.NewTeam(ctx, &fleet.Team{Name: t.Name() + "team2"}) + require.NoError(t, err) + require.NoError(t, s.ds.AddHostsToTeam(ctx, &team1.ID, []uint{tmHost.ID})) + + software := []fleet.Software{ + {Name: "foo", Version: "0.0.1", Source: "homebrew"}, + {Name: "foo", Version: "0.0.3", Source: "homebrew"}, + {Name: "bar", Version: "0.0.4", Source: "apps"}, + } + _, err = s.ds.UpdateHostSoftware(context.Background(), host.ID, software) + require.NoError(t, err) + require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host, false)) + + soft1 := host.Software[0] + if soft1.Name != "bar" { + soft1 = host.Software[1] + } + + cpes := []fleet.SoftwareCPE{{SoftwareID: soft1.ID, CPE: "somecpe"}} + _, err = s.ds.UpsertSoftwareCPEs(context.Background(), cpes) + require.NoError(t, err) + + // Reload software so that 'GeneratedCPEID is set. + require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host, false)) + soft1 = host.Software[0] + if soft1.Name != "bar" { + soft1 = host.Software[1] + } + + inserted, err := s.ds.InsertSoftwareVulnerability( + context.Background(), fleet.SoftwareVulnerability{ + SoftwareID: soft1.ID, + CVE: "cve-123-123-132", + }, fleet.NVDSource, + ) + require.NoError(t, err) + require.True(t, inserted) + + // calculate hosts counts + hostsCountTs := time.Now().UTC() + require.NoError(t, s.ds.SyncHostsSoftware(context.Background(), hostsCountTs)) + require.NoError(t, s.ds.ReconcileSoftwareTitles(ctx)) + + t.Run("GET /software/titles", func(t *testing.T) { + var resp listSoftwareTitlesResponse + s.DoJSON("GET", "/api/latest/fleet/software/titles", listSoftwareTitlesRequest{}, http.StatusOK, &resp) + require.Equal(t, 2, resp.Count) + softwareTitlesMatch([]fleet.SoftwareTitle{ + { + Name: "foo", + Source: "homebrew", + VersionsCount: 2, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.1", Vulnerabilities: nil}, + {Version: "0.0.3", Vulnerabilities: nil}, + }, + }, + { + Name: "bar", + Source: "apps", + VersionsCount: 1, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}}, + }, + }, + }, resp.SoftwareTitles) + + // per_page equals 1, so we get only one item, but the total count is + // still 2 + resp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &resp, + "per_page", "1", + "order_key", "name", + "order_direction", "desc", + ) + require.Equal(t, 2, resp.Count) + softwareTitlesMatch([]fleet.SoftwareTitle{ + { + Name: "foo", + Source: "homebrew", + VersionsCount: 2, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.1", Vulnerabilities: nil}, + {Version: "0.0.3", Vulnerabilities: nil}, + }, + }, + }, resp.SoftwareTitles) + + // get the second item + resp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &resp, + "per_page", "1", + "page", "1", + "order_key", "name", + "order_direction", "desc", + ) + require.Equal(t, 2, resp.Count) + softwareTitlesMatch([]fleet.SoftwareTitle{ + { + Name: "bar", + Source: "apps", + VersionsCount: 1, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}}, + }, + }, + }, resp.SoftwareTitles) + + // asking for a non-existent page returns an empty list + resp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &resp, + "per_page", "1", + "page", "4", + "order_key", "name", + "order_direction", "desc", + ) + require.Equal(t, 2, resp.Count) + softwareTitlesMatch(nil, resp.SoftwareTitles) + + // asking for vulnerable only software returns the expected values + resp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &resp, + "vulnerable", "true", + ) + require.Equal(t, 1, resp.Count) + softwareTitlesMatch([]fleet.SoftwareTitle{ + { + Name: "bar", + Source: "apps", + VersionsCount: 1, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}}, + }, + }, + }, resp.SoftwareTitles) + + // request titles for team1, nothing there yet + resp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &resp, + "team_id", "1", + ) + require.Equal(t, 0, resp.Count) + softwareTitlesMatch(nil, resp.SoftwareTitles) + + // add new software for tmHost + software = []fleet.Software{ + {Name: "foo", Version: "0.0.1", Source: "homebrew"}, + {Name: "baz", Version: "0.0.5", Source: "deb_packages"}, + } + _, err = s.ds.UpdateHostSoftware(context.Background(), tmHost.ID, software) + require.NoError(t, err) + + // calculate hosts counts + hostsCountTs := time.Now().UTC() + require.NoError(t, s.ds.SyncHostsSoftware(context.Background(), hostsCountTs)) + require.NoError(t, s.ds.ReconcileSoftwareTitles(ctx)) + + // request software for the team, this time we get results + resp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &resp, + "team_id", "1", + "order_key", "name", + "order_direction", "desc", + ) + require.Equal(t, 2, resp.Count) + softwareTitlesMatch([]fleet.SoftwareTitle{ + { + Name: "baz", + Source: "deb_packages", + VersionsCount: 1, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.5", Vulnerabilities: nil}, + }, + }, + { + Name: "foo", + Source: "homebrew", + VersionsCount: 1, // NOTE: this value is 1 because the team has only 1 matching host + HostsCount: 1, // NOTE: this value is 1 because the team has only 1 matching host + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.1", Vulnerabilities: nil}, + {Version: "0.0.3", Vulnerabilities: nil}, + }, + }, + }, resp.SoftwareTitles) + + // request software for no-team, we get all results and 2 hosts for + // `"foo"` + resp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &resp, + "order_key", "name", + "order_direction", "desc", + ) + require.Equal(t, 3, resp.Count) + softwareTitlesMatch([]fleet.SoftwareTitle{ + { + Name: "baz", + Source: "deb_packages", + VersionsCount: 1, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.5", Vulnerabilities: nil}, + }, + }, + { + Name: "foo", + Source: "homebrew", + VersionsCount: 2, // NOTE: this value is 2, important because no team filter was applied + HostsCount: 2, // NOTE: this value is 2, important because no team filter was applied + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.1", Vulnerabilities: nil}, + {Version: "0.0.3", Vulnerabilities: nil}, + }, + }, + { + Name: "bar", + Source: "apps", + VersionsCount: 1, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}}, + }, + }, + }, resp.SoftwareTitles) + + // match cve by name + resp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &resp, + "query", "123", + ) + require.Equal(t, 1, resp.Count) + softwareTitlesMatch([]fleet.SoftwareTitle{ + { + Name: "bar", + Source: "apps", + VersionsCount: 1, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}}, + }, + }, + }, resp.SoftwareTitles) + + // match software title by name + resp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &resp, + "query", "ba", + ) + require.Equal(t, 2, resp.Count) + softwareTitlesMatch([]fleet.SoftwareTitle{ + { + Name: "bar", + Source: "apps", + VersionsCount: 1, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.4", Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}}, + }, + }, + { + Name: "baz", + Source: "deb_packages", + VersionsCount: 1, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.5", Vulnerabilities: nil}, + }, + }, + }, resp.SoftwareTitles) + }) + + t.Run("GET /software/titles/:id", func(t *testing.T) { + // find the ID of "foo" + var softwareListResp listSoftwareTitlesResponse + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &softwareListResp, + "query", "foo", + ) + require.Equal(t, 1, softwareListResp.Count) + require.Len(t, softwareListResp.SoftwareTitles, 1) + fooTitle := softwareListResp.SoftwareTitles[0] + require.Equal(t, "foo", fooTitle.Name) + + // non-existent id is a 404 + var resp getSoftwareTitleResponse + s.DoJSON("GET", "/api/latest/fleet/software/titles/999", getSoftwareTitleRequest{}, http.StatusNotFound, &resp) + + // valid title + resp = getSoftwareTitleResponse{} + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/software/titles/%d", fooTitle.ID), getSoftwareTitleRequest{}, http.StatusOK, &resp) + softwareTitlesMatch([]fleet.SoftwareTitle{{ + Name: "foo", + Source: "homebrew", + VersionsCount: 2, + HostsCount: 2, + Versions: []fleet.SoftwareVersion{ + {Version: "0.0.1", Vulnerabilities: nil, HostsCount: ptr.Uint(2)}, + {Version: "0.0.3", Vulnerabilities: nil, HostsCount: ptr.Uint(1)}, + }}, + }, []fleet.SoftwareTitle{*resp.SoftwareTitle}) + + // find the ID of "bar" + softwareListResp = listSoftwareTitlesResponse{} + s.DoJSON( + "GET", "/api/latest/fleet/software/titles", + listSoftwareTitlesRequest{}, + http.StatusOK, &softwareListResp, + "query", "bar", + ) + require.Equal(t, 1, softwareListResp.Count) + require.Len(t, softwareListResp.SoftwareTitles, 1) + barTitle := softwareListResp.SoftwareTitles[0] + require.Equal(t, "bar", barTitle.Name) + + // valid title with vulnerabilities + resp = getSoftwareTitleResponse{} + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/software/titles/%d", barTitle.ID), getSoftwareTitleRequest{}, http.StatusOK, &resp) + softwareTitlesMatch([]fleet.SoftwareTitle{{ + Name: "bar", + Source: "apps", + VersionsCount: 1, + HostsCount: 1, + Versions: []fleet.SoftwareVersion{ + { + Version: "0.0.4", + Vulnerabilities: &fleet.SliceString{"cve-123-123-132"}, + HostsCount: ptr.Uint(1), + }, + }}, + }, []fleet.SoftwareTitle{*resp.SoftwareTitle}) + }) +} + // checks that the specified team/no-team has the Windows OS Updates profile with // the specified deadline/grace settings (or checks that it doesn't have the // profile if wantSettings is nil). It returns the profile_uuid if it exists, diff --git a/server/service/software_titles.go b/server/service/software_titles.go new file mode 100644 index 0000000000..926955c567 --- /dev/null +++ b/server/service/software_titles.go @@ -0,0 +1,126 @@ +package service + +import ( + "context" + "time" + + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" + "github.com/fleetdm/fleet/v4/server/fleet" +) + +///////////////////////////////////////////////////////////////////////////////// +// List Software Titles +///////////////////////////////////////////////////////////////////////////////// + +type listSoftwareTitlesRequest struct { + fleet.SoftwareTitleListOptions +} + +type listSoftwareTitlesResponse struct { + Meta *fleet.PaginationMetadata `json:"meta"` + Count int `json:"count"` + CountsUpdatedAt *time.Time `json:"counts_updated_at"` + SoftwareTitles []fleet.SoftwareTitle `json:"software_titles,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r listSoftwareTitlesResponse) error() error { return r.Err } + +func listSoftwareTitlesEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { + req := request.(*listSoftwareTitlesRequest) + titles, count, meta, err := svc.ListSoftwareTitles(ctx, req.SoftwareTitleListOptions) + if err != nil { + return listSoftwareTitlesResponse{Err: err}, nil + } + + var latest time.Time + for _, sw := range titles { + if !sw.CountsUpdatedAt.IsZero() && sw.CountsUpdatedAt.After(latest) { + latest = sw.CountsUpdatedAt + } + } + listResp := listSoftwareTitlesResponse{ + SoftwareTitles: titles, + Count: count, + Meta: meta, + } + if !latest.IsZero() { + listResp.CountsUpdatedAt = &latest + } + + return listResp, nil +} + +func (svc *Service) ListSoftwareTitles( + ctx context.Context, + opt fleet.SoftwareTitleListOptions, +) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) { + if err := svc.authz.Authorize(ctx, &fleet.AuthzSoftwareInventory{ + TeamID: opt.TeamID, + }, fleet.ActionRead); err != nil { + return nil, 0, nil, err + } + + if opt.TeamID != nil && *opt.TeamID != 0 { + lic, err := svc.License(ctx) + if err != nil { + return nil, 0, nil, ctxerr.Wrap(ctx, err, "get license") + } + if !lic.IsPremium() { + return nil, 0, nil, fleet.ErrMissingLicense + } + } + + // always include metadata for software titles + opt.ListOptions.IncludeMetadata = true + // cursor-based pagination is not supported for software titles + opt.ListOptions.After = "" + + titles, count, meta, err := svc.ds.ListSoftwareTitles(ctx, opt) + if err != nil { + return nil, 0, nil, err + } + + return titles, count, meta, nil +} + +///////////////////////////////////////////////////////////////////////////////// +// Get a Software Title +///////////////////////////////////////////////////////////////////////////////// + +type getSoftwareTitleRequest struct { + ID uint `url:"id"` +} + +type getSoftwareTitleResponse struct { + SoftwareTitle *fleet.SoftwareTitle `json:"software_title,omitempty"` + Err error `json:"error,omitempty"` +} + +func (r getSoftwareTitleResponse) error() error { return r.Err } + +func getSoftwareTitleEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { + req := request.(*getSoftwareTitleRequest) + + software, err := svc.SoftwareTitleByID(ctx, req.ID) + if err != nil { + return getSoftwareTitleResponse{Err: err}, nil + } + + return getSoftwareTitleResponse{SoftwareTitle: software}, nil +} + +func (svc *Service) SoftwareTitleByID(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) { + // TODO: this is the autorization we do for GET /software, does it look right? + // checking with product here: https://github.com/fleetdm/fleet/issues/14674#issuecomment-1841395788 + if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil { + return nil, err + } + + software, err := svc.ds.SoftwareTitleByID(ctx, id) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "getting software title by id") + } + + return software, nil +} diff --git a/server/service/software_titles_test.go b/server/service/software_titles_test.go new file mode 100644 index 0000000000..b172af3b48 --- /dev/null +++ b/server/service/software_titles_test.go @@ -0,0 +1,161 @@ +package service + +import ( + "context" + "testing" + + "github.com/fleetdm/fleet/v4/server/contexts/license" + "github.com/fleetdm/fleet/v4/server/contexts/viewer" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/mock" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/stretchr/testify/require" +) + +func TestServiceSoftwareTitlesAuth(t *testing.T) { + ds := new(mock.Store) + + ds.ListSoftwareTitlesFunc = func(ctx context.Context, opt fleet.SoftwareTitleListOptions) ([]fleet.SoftwareTitle, int, *fleet.PaginationMetadata, error) { + return []fleet.SoftwareTitle{}, 0, &fleet.PaginationMetadata{}, nil + } + ds.SoftwareTitleByIDFunc = func(ctx context.Context, id uint) (*fleet.SoftwareTitle, error) { + return &fleet.SoftwareTitle{}, nil + } + + svc, ctx := newTestService(t, ds, nil, nil) + + for _, tc := range []struct { + name string + user *fleet.User + shouldFailGlobalRead bool + shouldFailTeamRead bool + }{ + { + name: "global-admin", + user: &fleet.User{ + ID: 1, + GlobalRole: ptr.String(fleet.RoleAdmin), + }, + shouldFailGlobalRead: false, + shouldFailTeamRead: false, + }, + { + name: "global-maintainer", + user: &fleet.User{ + ID: 1, + GlobalRole: ptr.String(fleet.RoleMaintainer), + }, + shouldFailGlobalRead: false, + shouldFailTeamRead: false, + }, + { + name: "global-observer", + user: &fleet.User{ + ID: 1, + GlobalRole: ptr.String(fleet.RoleObserver), + }, + shouldFailGlobalRead: false, + shouldFailTeamRead: false, + }, + { + name: "team-admin-belongs-to-team", + user: &fleet.User{ + ID: 1, + Teams: []fleet.UserTeam{{ + Team: fleet.Team{ID: 1}, + Role: fleet.RoleAdmin, + }}, + }, + shouldFailGlobalRead: true, + shouldFailTeamRead: false, + }, + { + name: "team-maintainer-belongs-to-team", + user: &fleet.User{ + ID: 1, + Teams: []fleet.UserTeam{{ + Team: fleet.Team{ID: 1}, + Role: fleet.RoleMaintainer, + }}, + }, + shouldFailGlobalRead: true, + shouldFailTeamRead: false, + }, + { + name: "team-observer-belongs-to-team", + user: &fleet.User{ + ID: 1, + Teams: []fleet.UserTeam{{ + Team: fleet.Team{ID: 1}, + Role: fleet.RoleObserver, + }}, + }, + shouldFailGlobalRead: true, + shouldFailTeamRead: false, + }, + { + name: "team-admin-does-not-belong-to-team", + user: &fleet.User{ + ID: 1, + Teams: []fleet.UserTeam{{ + Team: fleet.Team{ID: 2}, + Role: fleet.RoleAdmin, + }}, + }, + shouldFailGlobalRead: true, + shouldFailTeamRead: true, + }, + { + name: "team-maintainer-does-not-belong-to-team", + user: &fleet.User{ + ID: 1, + Teams: []fleet.UserTeam{{ + Team: fleet.Team{ID: 2}, + Role: fleet.RoleMaintainer, + }}, + }, + shouldFailGlobalRead: true, + shouldFailTeamRead: true, + }, + { + name: "team-observer-does-not-belong-to-team", + user: &fleet.User{ + ID: 1, + Teams: []fleet.UserTeam{{ + Team: fleet.Team{ID: 2}, + Role: fleet.RoleObserver, + }}, + }, + shouldFailGlobalRead: true, + shouldFailTeamRead: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + ctx := viewer.NewContext(ctx, viewer.Viewer{User: tc.user}) + premiumCtx := license.NewContext(ctx, &fleet.LicenseInfo{Tier: fleet.TierPremium}) + + // List all software titles. + _, _, _, err := svc.ListSoftwareTitles(ctx, fleet.SoftwareTitleListOptions{}) + checkAuthErr(t, tc.shouldFailGlobalRead, err) + + // List software for a team. + _, _, _, err = svc.ListSoftwareTitles(premiumCtx, fleet.SoftwareTitleListOptions{ + TeamID: ptr.Uint(1), + }) + checkAuthErr(t, tc.shouldFailTeamRead, err) + + // List software for a team should fail no matter what + // with a non-premium context + if !tc.shouldFailTeamRead { + _, _, _, err = svc.ListSoftwareTitles(ctx, fleet.SoftwareTitleListOptions{ + TeamID: ptr.Uint(1), + }) + require.ErrorContains(t, err, "Requires Fleet Premium license") + } + + // Get a software title + _, err = svc.SoftwareTitleByID(ctx, 1) + checkAuthErr(t, false, err) + }) + } +}