From 84f45e54d093c44a8078f4d1eee929c3b39152a3 Mon Sep 17 00:00:00 2001 From: Tim Lee Date: Mon, 29 Apr 2024 12:37:25 -0600 Subject: [PATCH] 17745 queries backend (#18582) #17745 implement `merge_inherited` on the list queries endpoint to combine team and inherited queries. - [ ] Changes file added for user-visible changes in `changes/`, `orbit/changes/` or `ee/fleetd-chrome/changes`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [X] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [X] Added/updated tests - [X] Manual QA for all new/changed functionality --------- Co-authored-by: RachelElysia --- server/datastore/mysql/queries.go | 12 ++-- server/datastore/mysql/queries_test.go | 22 ++++++- server/fleet/app.go | 3 + server/fleet/service.go | 4 +- server/service/global_schedule.go | 2 +- server/service/integration_core_test.go | 8 +-- server/service/integration_enterprise_test.go | 61 ++++++++++++++++++- server/service/queries.go | 16 ++--- server/service/queries_test.go | 2 +- server/service/team_schedule.go | 2 +- 10 files changed, 109 insertions(+), 23 deletions(-) diff --git a/server/datastore/mysql/queries.go b/server/datastore/mysql/queries.go index 2a26884985..2ad343922c 100644 --- a/server/datastore/mysql/queries.go +++ b/server/datastore/mysql/queries.go @@ -4,11 +4,12 @@ import ( "context" "database/sql" "fmt" + "strings" + "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/go-kit/log/level" "github.com/jmoiron/sqlx" - "strings" ) const ( @@ -414,7 +415,6 @@ func (ds *Datastore) deleteQueryStats(ctx context.Context, queryIDs []uint) { level.Error(ds.logger).Log("msg", "error deleting aggregated stats", "err", err) } } - } // Query returns a single Query identified by id, if such exists. @@ -504,10 +504,14 @@ func (ds *Datastore) ListQueries(ctx context.Context, opt fleet.ListQueryOptions args := []interface{}{false, fleet.AggregatedStatsTypeScheduledQuery} whereClauses := "WHERE saved = true" - if opt.TeamID != nil { + switch { + case opt.TeamID != nil && opt.MergeInherited: + args = append(args, *opt.TeamID) + whereClauses += " AND (team_id = ? OR team_id IS NULL)" + case opt.TeamID != nil: args = append(args, *opt.TeamID) whereClauses += " AND team_id = ?" - } else { + default: whereClauses += " AND team_id IS NULL" } diff --git a/server/datastore/mysql/queries_test.go b/server/datastore/mysql/queries_test.go index 2096683477..f43c07fe7e 100644 --- a/server/datastore/mysql/queries_test.go +++ b/server/datastore/mysql/queries_test.go @@ -219,7 +219,6 @@ func testQueriesDelete(t *testing.T, ds *Datastore) { case <-time.After(10 * time.Second): t.Error("Timeout: stats not deleted for testQueriesDelete") } - } func testQueriesGetByName(t *testing.T, ds *Datastore) { @@ -765,6 +764,27 @@ func testListQueriesFiltersByTeamID(t *testing.T, ds *Datastore) { ) require.NoError(t, err) test.QueryElementsMatch(t, queries, []*fleet.Query{teamQ1, teamQ2, teamQ3}) + + // test merge inherited + queries, err = ds.ListQueries( + context.Background(), + fleet.ListQueryOptions{ + TeamID: &team.ID, + MergeInherited: true, + }, + ) + require.NoError(t, err) + test.QueryElementsMatch(t, queries, []*fleet.Query{globalQ1, globalQ2, globalQ3, teamQ1, teamQ2, teamQ3}) + + // merge inherited ignored for global queries + queries, err = ds.ListQueries( + context.Background(), + fleet.ListQueryOptions{ + MergeInherited: true, + }, + ) + require.NoError(t, err) + test.QueryElementsMatch(t, queries, []*fleet.Query{globalQ1, globalQ2, globalQ3}) } func testListQueriesFiltersByIsScheduled(t *testing.T, ds *Datastore) { diff --git a/server/fleet/app.go b/server/fleet/app.go index 6f99244500..ef5b06e5b2 100644 --- a/server/fleet/app.go +++ b/server/fleet/app.go @@ -1016,6 +1016,9 @@ type ListQueryOptions struct { TeamID *uint // IsScheduled filters queries that are meant to run at a set interval. IsScheduled *bool + // MergeInherited merges inherited global queries into the team list. Is only valid when TeamID + // is set. + MergeInherited bool } type ListActivitiesOptions struct { diff --git a/server/fleet/service.go b/server/fleet/service.go index 47d0e5fc5e..c7d02a7a97 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -270,7 +270,9 @@ type Service interface { // for distributed queries but not saved should not be returned). // When is set to scheduled != nil, then only scheduled queries will be returned if `*scheduled == true` // and only non-scheduled queries will be returned if `*scheduled == false`. - ListQueries(ctx context.Context, opt ListOptions, teamID *uint, scheduled *bool) ([]*Query, error) + // If mergeInherited is true and a teamID is provided, then queries from the global team will be + // included in the results. + ListQueries(ctx context.Context, opt ListOptions, teamID *uint, scheduled *bool, mergeInherited bool) ([]*Query, error) GetQuery(ctx context.Context, id uint) (*Query, error) // GetQueryReportResults returns all the stored results of a query for hosts the requestor has access to GetQueryReportResults(ctx context.Context, id uint) ([]HostQueryResultRow, error) diff --git a/server/service/global_schedule.go b/server/service/global_schedule.go index a8efa4c87d..c75d860486 100644 --- a/server/service/global_schedule.go +++ b/server/service/global_schedule.go @@ -37,7 +37,7 @@ func getGlobalScheduleEndpoint(ctx context.Context, request interface{}, svc fle } func (svc *Service) GetGlobalScheduledQueries(ctx context.Context, opts fleet.ListOptions) ([]*fleet.ScheduledQuery, error) { - queries, err := svc.ListQueries(ctx, opts, nil, ptr.Bool(true)) // teamID == nil means global + queries, err := svc.ListQueries(ctx, opts, nil, ptr.Bool(true), false) // teamID == nil means global if err != nil { return nil, err } diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index a9e4c6ffc7..4a739e189e 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -263,7 +263,7 @@ func (s *integrationTestSuite) TestQueryCreationLogsActivity() { } var createQueryResp createQueryResponse s.DoJSON("POST", "/api/latest/fleet/queries", ¶ms, http.StatusOK, &createQueryResp) - defer cleanupQuery(s, createQueryResp.Query.ID) + defer s.cleanupQuery(createQueryResp.Query.ID) activities := listActivitiesResponse{} s.DoJSON("GET", "/api/latest/fleet/activities", nil, http.StatusOK, &activities) @@ -1579,7 +1579,7 @@ func (s *integrationTestSuite) TestListHosts() { user1 := test.NewUser(t, s.ds, "Alice", "alice@example.com", true) q := test.NewQuery(t, s.ds, nil, "query1", "select 1", 0, true) - defer cleanupQuery(s, q.ID) + defer s.cleanupQuery(q.ID) globalPolicy0, err := s.ds.NewGlobalPolicy( context.Background(), &user1.ID, fleet.PolicyPayload{ QueryID: &q.ID, @@ -5791,7 +5791,7 @@ func (s *integrationTestSuite) TestQueriesBadRequests() { s.DoJSON("POST", "/api/latest/fleet/queries", reqQuery, http.StatusOK, &createQueryResp) require.NotNil(t, createQueryResp.Query) existingQueryID := createQueryResp.Query.ID - defer cleanupQuery(s, existingQueryID) + defer s.cleanupQuery(existingQueryID) for _, tc := range []struct { tname string @@ -9011,7 +9011,7 @@ func createSession(t *testing.T, uid uint, ds fleet.Datastore) *fleet.Session { return ssn } -func cleanupQuery(s *integrationTestSuite, queryID uint) { +func (s *integrationTestSuite) cleanupQuery(queryID uint) { var delResp deleteQueryByIDResponse s.DoJSON("DELETE", fmt.Sprintf("/api/latest/fleet/queries/id/%d", queryID), nil, http.StatusOK, &delResp) } diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index b50bd25c7b..90b4c14b16 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -853,6 +853,53 @@ func (s *integrationEnterpriseTestSuite) TestTeamPolicies() { require.Len(t, ts.Policies, 0) } +func (s *integrationEnterpriseTestSuite) TestTeamQueries() { + t := s.T() + + team1, err := s.ds.NewTeam(context.Background(), &fleet.Team{ + ID: 42, + Name: "team1" + t.Name(), + Description: "desc team1", + }) + require.NoError(t, err) + + oldToken := s.token + t.Cleanup(func() { + s.token = oldToken + }) + + // create global query + params := fleet.QueryPayload{ + Name: ptr.String("global1"), + Query: ptr.String("select * from time;"), + } + var createQueryResp createQueryResponse + s.DoJSON("POST", "/api/latest/fleet/queries", ¶ms, http.StatusOK, &createQueryResp) + defer s.cleanupQuery(createQueryResp.Query.ID) + + // create team query + params = fleet.QueryPayload{ + Name: ptr.String("team1"), + Query: ptr.String("select * from time;"), + TeamID: ptr.Uint(team1.ID), + } + createQueryResp = createQueryResponse{} + s.DoJSON("POST", "/api/latest/fleet/queries", ¶ms, http.StatusOK, &createQueryResp) + defer s.cleanupQuery(createQueryResp.Query.ID) + + // list team queries + var listQueriesResp listQueriesResponse + s.DoJSON("GET", "/api/latest/fleet/queries", nil, http.StatusOK, &listQueriesResp, "team_id", fmt.Sprint(team1.ID)) + require.Len(t, listQueriesResp.Queries, 1) + assert.Equal(t, "team1", listQueriesResp.Queries[0].Name) + + // list merged team queries + s.DoJSON("GET", "/api/latest/fleet/queries", nil, http.StatusOK, &listQueriesResp, "team_id", fmt.Sprint(team1.ID), "merge_inherited", "true", "order_key", "team_id", "order_direction", "desc") + require.Len(t, listQueriesResp.Queries, 2) + assert.Equal(t, "team1", listQueriesResp.Queries[0].Name) + assert.Equal(t, "global1", listQueriesResp.Queries[1].Name) +} + func (s *integrationEnterpriseTestSuite) TestModifyTeamEnrollSecrets() { t := s.T() @@ -2840,7 +2887,8 @@ func (s *integrationEnterpriseTestSuite) TestMDMMacOSUpdates() { // edited macos min version activity got created s.lastActivityMatches(fleet.ActivityTypeEditedMacOSMinVersion{}.ActivityName(), `{"deadline":"2022-01-01", "minimum_version":"12.3.1", "team_id": null, "team_name": null}`, 0) s.assertMacOSUpdatesDeclaration(nil, &fleet.MacOSUpdates{ - MinimumVersion: optjson.SetString("12.3.1"), Deadline: optjson.SetString("2022-01-01")}) + MinimumVersion: optjson.SetString("12.3.1"), Deadline: optjson.SetString("2022-01-01"), + }) // get the appconfig acResp = appConfigResponse{} @@ -2864,7 +2912,8 @@ func (s *integrationEnterpriseTestSuite) TestMDMMacOSUpdates() { // another edited macos min version activity got created lastActivity = s.lastActivityMatches(fleet.ActivityTypeEditedMacOSMinVersion{}.ActivityName(), `{"deadline":"2024-01-01", "minimum_version":"12.3.1", "team_id": null, "team_name": null}`, 0) s.assertMacOSUpdatesDeclaration(nil, &fleet.MacOSUpdates{ - MinimumVersion: optjson.SetString("12.3.1"), Deadline: optjson.SetString("2024-01-01")}) + MinimumVersion: optjson.SetString("12.3.1"), Deadline: optjson.SetString("2024-01-01"), + }) // update something unrelated - the transparency url acResp = appConfigResponse{} @@ -2875,7 +2924,8 @@ func (s *integrationEnterpriseTestSuite) TestMDMMacOSUpdates() { // no activity got created s.lastActivityMatches("", ``, lastActivity) s.assertMacOSUpdatesDeclaration(nil, &fleet.MacOSUpdates{ - MinimumVersion: optjson.SetString("12.3.1"), Deadline: optjson.SetString("2024-01-01")}) + MinimumVersion: optjson.SetString("12.3.1"), Deadline: optjson.SetString("2024-01-01"), + }) // clear the macos requirement acResp = appConfigResponse{} @@ -8653,3 +8703,8 @@ func triggerAndWait(ctx context.Context, t *testing.T, ds fleet.Datastore, s *sc } } } + +func (s *integrationEnterpriseTestSuite) cleanupQuery(queryID uint) { + var delResp deleteQueryByIDResponse + s.DoJSON("DELETE", fmt.Sprintf("/api/latest/fleet/queries/id/%d", queryID), nil, http.StatusOK, &delResp) +} diff --git a/server/service/queries.go b/server/service/queries.go index a4d39fef83..988f3d232b 100644 --- a/server/service/queries.go +++ b/server/service/queries.go @@ -58,7 +58,8 @@ func (svc *Service) GetQuery(ctx context.Context, id uint) (*fleet.Query, error) type listQueriesRequest struct { ListOptions fleet.ListOptions `url:"list_options"` // TeamID url argument set to 0 means global. - TeamID uint `query:"team_id,optional"` + TeamID uint `query:"team_id,optional"` + MergeInherited bool `query:"merge_inherited,optional"` } type listQueriesResponse struct { @@ -76,7 +77,7 @@ func listQueriesEndpoint(ctx context.Context, request interface{}, svc fleet.Ser teamID = &req.TeamID } - queries, err := svc.ListQueries(ctx, req.ListOptions, teamID, nil) + queries, err := svc.ListQueries(ctx, req.ListOptions, teamID, nil, req.MergeInherited) if err != nil { return listQueriesResponse{Err: err}, nil } @@ -90,7 +91,7 @@ func listQueriesEndpoint(ctx context.Context, request interface{}, svc fleet.Ser }, nil } -func (svc *Service) ListQueries(ctx context.Context, opt fleet.ListOptions, teamID *uint, scheduled *bool) ([]*fleet.Query, error) { +func (svc *Service) ListQueries(ctx context.Context, opt fleet.ListOptions, teamID *uint, scheduled *bool, mergeInherited bool) ([]*fleet.Query, error) { // Check the user is allowed to list queries on the given team. if err := svc.authz.Authorize(ctx, &fleet.Query{ TeamID: teamID, @@ -99,9 +100,10 @@ func (svc *Service) ListQueries(ctx context.Context, opt fleet.ListOptions, team } queries, err := svc.ds.ListQueries(ctx, fleet.ListQueryOptions{ - ListOptions: opt, - TeamID: teamID, - IsScheduled: scheduled, + ListOptions: opt, + TeamID: teamID, + IsScheduled: scheduled, + MergeInherited: mergeInherited, }) if err != nil { return nil, err @@ -733,7 +735,7 @@ func getQuerySpecsEndpoint(ctx context.Context, request interface{}, svc fleet.S } func (svc *Service) GetQuerySpecs(ctx context.Context, teamID *uint) ([]*fleet.QuerySpec, error) { - queries, err := svc.ListQueries(ctx, fleet.ListOptions{}, teamID, nil) + queries, err := svc.ListQueries(ctx, fleet.ListOptions{}, teamID, nil, false) if err != nil { return nil, ctxerr.Wrap(ctx, err, "getting queries") } diff --git a/server/service/queries_test.go b/server/service/queries_test.go index 0fc1a44aec..9b9cdfb1c9 100644 --- a/server/service/queries_test.go +++ b/server/service/queries_test.go @@ -632,7 +632,7 @@ func TestQueryAuth(t *testing.T) { _, err = svc.QueryReportIsClipped(ctx, tt.qid) checkAuthErr(t, tt.shouldFailRead, err) - _, err = svc.ListQueries(ctx, fleet.ListOptions{}, query.TeamID, nil) + _, err = svc.ListQueries(ctx, fleet.ListOptions{}, query.TeamID, nil, false) checkAuthErr(t, tt.shouldFailRead, err) teamName := "" diff --git a/server/service/team_schedule.go b/server/service/team_schedule.go index 24ad3cde3b..da31740a77 100644 --- a/server/service/team_schedule.go +++ b/server/service/team_schedule.go @@ -47,7 +47,7 @@ func (svc Service) GetTeamScheduledQueries(ctx context.Context, teamID uint, opt if teamID != 0 { teamID_ = &teamID } - queries, err := svc.ListQueries(ctx, opts, teamID_, ptr.Bool(true)) + queries, err := svc.ListQueries(ctx, opts, teamID_, ptr.Bool(true), false) if err != nil { return nil, err }