From e1aac9c776252192dc751e397ace411c8e922481 Mon Sep 17 00:00:00 2001 From: Victor Lyuboslavsky Date: Wed, 14 Feb 2024 09:43:21 -0600 Subject: [PATCH] Adding tests for new live query endpoints. (#16823) # Checklist for submitter If some of the following don't apply, delete the relevant line. - [ ] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [ ] Documented any permissions changes (docs/Using Fleet/manage-access.md) - [ ] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [ ] Added support on fleet's osquery simulator `cmd/osquery-perf` for new osquery data ingestion features. - [ ] Added/updated tests - [ ] If database migrations are included, checked table schema to confirm autoupdate - For database migrations: - [ ] Checked schema for all modified table for columns that will auto-update timestamps during migration. - [ ] Confirmed that updating the timestamps is acceptable, and will not cause unwanted side effects. - [ ] Manual QA for all new/changed functionality - For Orbit and Fleet Desktop changes: - [ ] Manual QA must be performed in the three main OSs, macOS, Windows and Linux. - [ ] Auto-update manual QA, from released version of component to new version (see [tools/tuf/test](../tools/tuf/test/README.md)). --- .../service/integration_live_queries_test.go | 34 +++++++++++++++---- server/service/live_queries.go | 14 ++++---- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/server/service/integration_live_queries_test.go b/server/service/integration_live_queries_test.go index dc3209281f..91bb4a2389 100644 --- a/server/service/integration_live_queries_test.go +++ b/server/service/integration_live_queries_test.go @@ -31,7 +31,7 @@ import ( func TestIntegrationLiveQueriesTestSuite(t *testing.T) { testingSuite := new(liveQueriesTestSuite) - testingSuite.s = &testingSuite.Suite + testingSuite.withServer.s = &testingSuite.Suite suite.Run(t, testingSuite) } @@ -91,7 +91,8 @@ type liveQueryEndpoint int const ( oldEndpoint liveQueryEndpoint = iota oneQueryEndpoint - customQueryOneHostEndpoint + customQueryOneHostIdEndpoint + customQueryOneHostIdentifierEndpoint ) func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { @@ -140,14 +141,18 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { defer wg.Done() s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) }() - } else { // customQueryOneHostEndpoint + } else { // customQueryOneHostId(.*)Endpoint liveQueryRequest := runLiveQueryOnHostRequest{ Query: query, } + url := fmt.Sprintf("/api/latest/fleet/hosts/%d/query", host.ID) + if endpoint == customQueryOneHostIdentifierEndpoint { + url = fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID) + } go func() { defer wg.Done() s.DoJSON( - "POST", fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID), liveQueryRequest, http.StatusOK, + "POST", url, liveQueryRequest, http.StatusOK, &liveQueryOnHostResp, ) }() @@ -160,7 +165,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() for range ticker.C { - if endpoint == customQueryOneHostEndpoint { + if endpoint == customQueryOneHostIdentifierEndpoint || endpoint == customQueryOneHostIdEndpoint { campaign := fleet.DistributedQueryCampaign{} err := mysql.ExecAdhocSQLWithError( s.ds, func(q sqlx.ExtContext) error { @@ -239,7 +244,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID) require.Len(t, liveQueryResp.Results[0].Results, 1) result = liveQueryResp.Results[0].Results[0] - } else { // customQueryOneHostEndpoint + } else { // customQueryOneHostId(.*)Endpoint assert.Empty(t, liveQueryOnHostResp.Error) assert.Equal(t, host.ID, liveQueryOnHostResp.HostID) assert.Equal(t, fleet.StatusOnline, liveQueryOnHostResp.Status) @@ -321,7 +326,8 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { s.Run("not saved query", func() { test(oneQueryEndpoint, false, true) }) s.Run("saved query without stats", func() { test(oneQueryEndpoint, true, false) }) s.Run("saved query with stats", func() { test(oneQueryEndpoint, true, true) }) - s.Run("custom query", func() { test(customQueryOneHostEndpoint, false, false) }) + s.Run("custom query by host id", func() { test(customQueryOneHostIdEndpoint, false, false) }) + s.Run("custom query by host identifier", func() { test(customQueryOneHostIdentifierEndpoint, false, false) }) } func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() { @@ -695,6 +701,20 @@ func (s *liveQueriesTestSuite) TestLiveQueriesInvalidInputs() { HostIDs: nil, } s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp) + + // Invalid raw query + liveQueryOnHostRequest := runLiveQueryOnHostRequest{ + Query: " ", + } + liveQueryOnHostResp := runLiveQueryOnHostResponse{} + s.DoJSON( + "POST", fmt.Sprintf("/api/latest/fleet/hosts/%d/query", host.ID), liveQueryOnHostRequest, http.StatusBadRequest, + &liveQueryOnHostResp, + ) + s.DoJSON( + "POST", fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID), liveQueryOnHostRequest, http.StatusBadRequest, + &liveQueryOnHostResp, + ) } // TestLiveQueriesFailsToAuthorize when an observer tries to run a live query diff --git a/server/service/live_queries.go b/server/service/live_queries.go index 811eb5122b..06b1cabf8b 100644 --- a/server/service/live_queries.go +++ b/server/service/live_queries.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "strconv" + "strings" "sync" "time" @@ -128,10 +129,6 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se func runLiveQueryOnHostEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { req := request.(*runLiveQueryOnHostRequest) - if req.Query == "" { - return nil, ctxerr.Wrap(ctx, badRequest("query is required")) - } - host, err := svc.HostLiteByIdentifier(ctx, req.Identifier) if err != nil { return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %s: %s", req.Identifier, err.Error()))) @@ -143,10 +140,6 @@ func runLiveQueryOnHostEndpoint(ctx context.Context, request interface{}, svc fl func runLiveQueryOnHostByIDEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { req := request.(*runLiveQueryOnHostByIDRequest) - if req.Query == "" { - return nil, ctxerr.Wrap(ctx, badRequest("query is required")) - } - host, err := svc.HostLiteByID(ctx, req.HostID) if err != nil { return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %d: %s", req.HostID, err.Error()))) @@ -156,6 +149,11 @@ func runLiveQueryOnHostByIDEndpoint(ctx context.Context, request interface{}, sv } func runLiveQueryOnHost(svc fleet.Service, ctx context.Context, host *fleet.HostLite, query string) (errorer, error) { + query = strings.TrimSpace(query) + if query == "" { + return nil, ctxerr.Wrap(ctx, badRequest("query is required")) + } + res := runLiveQueryOnHostResponse{ HostID: host.ID, Query: query,