Adding tests for new live query endpoints. (#16823)

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

<!-- Note that API documentation changes are now addressed by the
product design team. -->

- [ ] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- [ ] Documented any permissions changes (docs/Using
Fleet/manage-access.md)
- [ ] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [ ] Added support on fleet's osquery simulator `cmd/osquery-perf` for
new osquery data ingestion features.
- [ ] Added/updated tests
- [ ] If database migrations are included, checked table schema to
confirm autoupdate
- For database migrations:
- [ ] Checked schema for all modified table for columns that will
auto-update timestamps during migration.
- [ ] Confirmed that updating the timestamps is acceptable, and will not
cause unwanted side effects.
- [ ] Manual QA for all new/changed functionality
  - For Orbit and Fleet Desktop changes:
- [ ] Manual QA must be performed in the three main OSs, macOS, Windows
and Linux.
- [ ] Auto-update manual QA, from released version of component to new
version (see [tools/tuf/test](../tools/tuf/test/README.md)).
This commit is contained in:
Victor Lyuboslavsky 2024-02-14 09:43:21 -06:00 committed by GitHub
parent a5a7df4527
commit e1aac9c776
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 33 additions and 15 deletions

View file

@ -31,7 +31,7 @@ import (
func TestIntegrationLiveQueriesTestSuite(t *testing.T) {
testingSuite := new(liveQueriesTestSuite)
testingSuite.s = &testingSuite.Suite
testingSuite.withServer.s = &testingSuite.Suite
suite.Run(t, testingSuite)
}
@ -91,7 +91,8 @@ type liveQueryEndpoint int
const (
oldEndpoint liveQueryEndpoint = iota
oneQueryEndpoint
customQueryOneHostEndpoint
customQueryOneHostIdEndpoint
customQueryOneHostIdentifierEndpoint
)
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
@ -140,14 +141,18 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
defer wg.Done()
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()
} else { // customQueryOneHostEndpoint
} else { // customQueryOneHostId(.*)Endpoint
liveQueryRequest := runLiveQueryOnHostRequest{
Query: query,
}
url := fmt.Sprintf("/api/latest/fleet/hosts/%d/query", host.ID)
if endpoint == customQueryOneHostIdentifierEndpoint {
url = fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID)
}
go func() {
defer wg.Done()
s.DoJSON(
"POST", fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID), liveQueryRequest, http.StatusOK,
"POST", url, liveQueryRequest, http.StatusOK,
&liveQueryOnHostResp,
)
}()
@ -160,7 +165,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for range ticker.C {
if endpoint == customQueryOneHostEndpoint {
if endpoint == customQueryOneHostIdentifierEndpoint || endpoint == customQueryOneHostIdEndpoint {
campaign := fleet.DistributedQueryCampaign{}
err := mysql.ExecAdhocSQLWithError(
s.ds, func(q sqlx.ExtContext) error {
@ -239,7 +244,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
require.Len(t, liveQueryResp.Results[0].Results, 1)
result = liveQueryResp.Results[0].Results[0]
} else { // customQueryOneHostEndpoint
} else { // customQueryOneHostId(.*)Endpoint
assert.Empty(t, liveQueryOnHostResp.Error)
assert.Equal(t, host.ID, liveQueryOnHostResp.HostID)
assert.Equal(t, fleet.StatusOnline, liveQueryOnHostResp.Status)
@ -321,7 +326,8 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
s.Run("not saved query", func() { test(oneQueryEndpoint, false, true) })
s.Run("saved query without stats", func() { test(oneQueryEndpoint, true, false) })
s.Run("saved query with stats", func() { test(oneQueryEndpoint, true, true) })
s.Run("custom query", func() { test(customQueryOneHostEndpoint, false, false) })
s.Run("custom query by host id", func() { test(customQueryOneHostIdEndpoint, false, false) })
s.Run("custom query by host identifier", func() { test(customQueryOneHostIdentifierEndpoint, false, false) })
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
@ -695,6 +701,20 @@ func (s *liveQueriesTestSuite) TestLiveQueriesInvalidInputs() {
HostIDs: nil,
}
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp)
// Invalid raw query
liveQueryOnHostRequest := runLiveQueryOnHostRequest{
Query: " ",
}
liveQueryOnHostResp := runLiveQueryOnHostResponse{}
s.DoJSON(
"POST", fmt.Sprintf("/api/latest/fleet/hosts/%d/query", host.ID), liveQueryOnHostRequest, http.StatusBadRequest,
&liveQueryOnHostResp,
)
s.DoJSON(
"POST", fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID), liveQueryOnHostRequest, http.StatusBadRequest,
&liveQueryOnHostResp,
)
}
// TestLiveQueriesFailsToAuthorize when an observer tries to run a live query

View file

@ -6,6 +6,7 @@ import (
"fmt"
"os"
"strconv"
"strings"
"sync"
"time"
@ -128,10 +129,6 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
func runLiveQueryOnHostEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*runLiveQueryOnHostRequest)
if req.Query == "" {
return nil, ctxerr.Wrap(ctx, badRequest("query is required"))
}
host, err := svc.HostLiteByIdentifier(ctx, req.Identifier)
if err != nil {
return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %s: %s", req.Identifier, err.Error())))
@ -143,10 +140,6 @@ func runLiveQueryOnHostEndpoint(ctx context.Context, request interface{}, svc fl
func runLiveQueryOnHostByIDEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*runLiveQueryOnHostByIDRequest)
if req.Query == "" {
return nil, ctxerr.Wrap(ctx, badRequest("query is required"))
}
host, err := svc.HostLiteByID(ctx, req.HostID)
if err != nil {
return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %d: %s", req.HostID, err.Error())))
@ -156,6 +149,11 @@ func runLiveQueryOnHostByIDEndpoint(ctx context.Context, request interface{}, sv
}
func runLiveQueryOnHost(svc fleet.Service, ctx context.Context, host *fleet.HostLite, query string) (errorer, error) {
query = strings.TrimSpace(query)
if query == "" {
return nil, ctxerr.Wrap(ctx, badRequest("query is required"))
}
res := runLiveQueryOnHostResponse{
HostID: host.ID,
Query: query,