From 967eddcb37cb2ed3bb03508917fc60ff9eaf6c59 Mon Sep 17 00:00:00 2001 From: Victor Lyuboslavsky Date: Tue, 13 Feb 2024 22:45:07 -0600 Subject: [PATCH] New live query API endpoint for custom query SQL. (#16810) #16805 - [X] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [X] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [X] Added/updated tests - [X] Manual QA for all new/changed functionality --------- Co-authored-by: Lucas Rodriguez --- changes/16805-new-live-query-on-host-endpoint | 1 + server/datastore/mysql/hosts.go | 58 +++++++++ server/fleet/datastore.go | 5 + server/fleet/hosts.go | 14 +++ server/fleet/service.go | 9 +- server/mock/datastore_mock.go | 24 ++++ server/service/handler.go | 2 + server/service/hosts.go | 38 ++++++ .../service/integration_live_queries_test.go | 94 ++++++++++---- server/service/live_queries.go | 118 +++++++++++++++++- 10 files changed, 335 insertions(+), 28 deletions(-) create mode 100644 changes/16805-new-live-query-on-host-endpoint diff --git a/changes/16805-new-live-query-on-host-endpoint b/changes/16805-new-live-query-on-host-endpoint new file mode 100644 index 0000000000..84918569b3 --- /dev/null +++ b/changes/16805-new-live-query-on-host-endpoint @@ -0,0 +1 @@ +* Add two new API endpoints to run a live query SQL on one host: `POST /api/latest/fleet/hosts/identifier/{identifier}/query` and `POST /api/_version_/fleet/hosts/{id}/query`. diff --git a/server/datastore/mysql/hosts.go b/server/datastore/mysql/hosts.go index c2af410d7d..af68b08f80 100644 --- a/server/datastore/mysql/hosts.go +++ b/server/datastore/mysql/hosts.go @@ -4905,3 +4905,61 @@ func (ds *Datastore) GetHostHealth(ctx context.Context, id uint) (*fleet.HostHea return &hh, nil } + +func (ds *Datastore) HostLiteByIdentifier(ctx context.Context, identifier string) (*fleet.HostLite, error) { + return ds.loadHostLite(ctx, nil, &identifier) +} + +func (ds *Datastore) HostLiteByID(ctx context.Context, id uint) (*fleet.HostLite, error) { + return ds.loadHostLite(ctx, &id, nil) +} + +func (ds *Datastore) loadHostLite(ctx context.Context, id *uint, identifier *string) (*fleet.HostLite, error) { + if id == nil && identifier == nil { + return nil, errors.New("must set one of id or identifier") + } + if id != nil && identifier != nil { + return nil, errors.New("cannot set both id and identifier") + } + stmt := ` + SELECT + h.id, + h.team_id, + h.osquery_host_id, + h.node_key, + h.hostname, + h.uuid, + h.hardware_serial, + h.distributed_interval, + h.config_tls_refresh, + COALESCE(hst.seen_time, h.created_at) AS seen_time + FROM hosts h + LEFT JOIN host_seen_times hst ON (h.id = hst.host_id) + %s + LIMIT 1 + ` + var ( + arg interface{} + whereClause string + ) + if identifier != nil { + whereClause = "WHERE ? IN (h.hostname, h.osquery_host_id, h.node_key, h.uuid, h.hardware_serial)" + arg = identifier + } else { + whereClause = "WHERE id = ?" + arg = id + } + host := &fleet.HostLite{} + err := sqlx.GetContext(ctx, ds.reader(ctx), host, fmt.Sprintf(stmt, whereClause), arg) + if err != nil { + if err == sql.ErrNoRows { + if identifier != nil { + return nil, ctxerr.Wrap(ctx, notFound("Host").WithName(*identifier)) + } + return nil, ctxerr.Wrap(ctx, notFound("Host").WithID(*id)) + } + return nil, ctxerr.Wrap(ctx, err, "get host lite") + } + + return host, nil +} diff --git a/server/fleet/datastore.go b/server/fleet/datastore.go index 7b48904199..c4a1aa5840 100644 --- a/server/fleet/datastore.go +++ b/server/fleet/datastore.go @@ -258,6 +258,11 @@ type Datastore interface { // HostByIdentifier returns one host matching the provided identifier. Possible matches can be on // osquery_host_identifier, node_key, UUID, or hostname. HostByIdentifier(ctx context.Context, identifier string) (*Host, error) + // HostLiteByIdentifier returns a host and a subset of its fields using an "identifier" string. + // The identifier string will be matched against the hostname, osquery_host_id, node_key, uuid and hardware_serial columns. + HostLiteByIdentifier(ctx context.Context, identifier string) (*HostLite, error) + // HostLiteByIdentifier returns a host and a subset of its fields from its id. + HostLiteByID(ctx context.Context, id uint) (*HostLite, error) // AddHostsToTeam adds hosts to an existing team, clearing their team settings if teamID is nil. AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs []uint) error diff --git a/server/fleet/hosts.go b/server/fleet/hosts.go index 0e96d18fda..bdb6dc2259 100644 --- a/server/fleet/hosts.go +++ b/server/fleet/hosts.go @@ -1201,3 +1201,17 @@ type HostMacOSProfile struct { // InstallDate is the date the profile was installed on the host as reported by the host's clock. InstallDate time.Time `json:"install_date" db:"install_date"` } + +// HostLite contains a subset of Host fields. +type HostLite struct { + ID uint `db:"id"` + TeamID *uint `db:"team_id"` + Hostname string `db:"hostname"` + OsqueryHostID string `db:"osquery_host_id"` + NodeKey string `db:"node_key"` + UUID string `db:"uuid"` + HardwareSerial string `db:"hardware_serial"` + SeenTime time.Time `db:"seen_time"` + DistributedInterval uint `db:"distributed_interval"` + ConfigTLSRefresh uint `db:"config_tls_refresh"` +} diff --git a/server/fleet/service.go b/server/fleet/service.go index ebe359cd32..e0e870d20c 100644 --- a/server/fleet/service.go +++ b/server/fleet/service.go @@ -308,7 +308,9 @@ type Service interface { GetCampaignReader(ctx context.Context, campaign *DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error) CompleteCampaign(ctx context.Context, campaign *DistributedQueryCampaign) error - RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]QueryCampaignResult, int, error) + RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, query string, hostIDs []uint, deadline time.Duration) ( + []QueryCampaignResult, int, error, + ) // ///////////////////////////////////////////////////////////////////////////// // AgentOptionsService @@ -364,6 +366,11 @@ type Service interface { // device-authenticated API), or manually by the user (via the // user-authenticated API). SetCustomHostDeviceMapping(ctx context.Context, hostID uint, email string) ([]*HostDeviceMapping, error) + // HostLiteByIdentifier returns a host and a subset of its fields using an "identifier" string. + // The identifier string will be matched against the Hostname, OsqueryHostID, NodeKey, UUID and HardwareSerial fields. + HostLiteByIdentifier(ctx context.Context, identifier string) (*HostLite, error) + // HostLiteByIdentifier returns a host and a subset of its fields from its id. + HostLiteByID(ctx context.Context, id uint) (*HostLite, error) // ListDevicePolicies lists all policies for the given host, including passing / failing summaries ListDevicePolicies(ctx context.Context, host *Host) ([]*HostPolicy, error) diff --git a/server/mock/datastore_mock.go b/server/mock/datastore_mock.go index da3b32cd6e..f527e39272 100644 --- a/server/mock/datastore_mock.go +++ b/server/mock/datastore_mock.go @@ -192,6 +192,10 @@ type HostIDsByOSVersionFunc func(ctx context.Context, osVersion fleet.OSVersion, type HostByIdentifierFunc func(ctx context.Context, identifier string) (*fleet.Host, error) +type HostLiteByIdentifierFunc func(ctx context.Context, identifier string) (*fleet.HostLite, error) + +type HostLiteByIDFunc func(ctx context.Context, id uint) (*fleet.HostLite, error) + type AddHostsToTeamFunc func(ctx context.Context, teamID *uint, hostIDs []uint) error type TotalAndUnseenHostsSinceFunc func(ctx context.Context, daysCount int) (total int, unseen int, err error) @@ -1078,6 +1082,12 @@ type DataStore struct { HostByIdentifierFunc HostByIdentifierFunc HostByIdentifierFuncInvoked bool + HostLiteByIdentifierFunc HostLiteByIdentifierFunc + HostLiteByIdentifierFuncInvoked bool + + HostLiteByIDFunc HostLiteByIDFunc + HostLiteByIDFuncInvoked bool + AddHostsToTeamFunc AddHostsToTeamFunc AddHostsToTeamFuncInvoked bool @@ -2626,6 +2636,20 @@ func (s *DataStore) HostByIdentifier(ctx context.Context, identifier string) (*f return s.HostByIdentifierFunc(ctx, identifier) } +func (s *DataStore) HostLiteByIdentifier(ctx context.Context, identifier string) (*fleet.HostLite, error) { + s.mu.Lock() + s.HostLiteByIdentifierFuncInvoked = true + s.mu.Unlock() + return s.HostLiteByIdentifierFunc(ctx, identifier) +} + +func (s *DataStore) HostLiteByID(ctx context.Context, id uint) (*fleet.HostLite, error) { + s.mu.Lock() + s.HostLiteByIDFuncInvoked = true + s.mu.Unlock() + return s.HostLiteByIDFunc(ctx, id) +} + func (s *DataStore) AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs []uint) error { s.mu.Lock() s.AddHostsToTeamFuncInvoked = true diff --git a/server/service/handler.go b/server/service/handler.go index 45ff9fe669..d6ea2100ed 100644 --- a/server/service/handler.go +++ b/server/service/handler.go @@ -378,6 +378,8 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC ue.GET("/api/_version_/fleet/hosts/count", countHostsEndpoint, countHostsRequest{}) ue.POST("/api/_version_/fleet/hosts/search", searchHostsEndpoint, searchHostsRequest{}) ue.GET("/api/_version_/fleet/hosts/identifier/{identifier}", hostByIdentifierEndpoint, hostByIdentifierRequest{}) + ue.POST("/api/_version_/fleet/hosts/identifier/{identifier}/query", runLiveQueryOnHostEndpoint, runLiveQueryOnHostRequest{}) + ue.POST("/api/_version_/fleet/hosts/{id:[0-9]+}/query", runLiveQueryOnHostByIDEndpoint, runLiveQueryOnHostByIDRequest{}) ue.DELETE("/api/_version_/fleet/hosts/{id:[0-9]+}", deleteHostEndpoint, deleteHostRequest{}) ue.POST("/api/_version_/fleet/hosts/transfer", addHostsToTeamEndpoint, addHostsToTeamRequest{}) ue.POST("/api/_version_/fleet/hosts/transfer/filter", addHostsToTeamByFilterEndpoint, addHostsToTeamByFilterRequest{}) diff --git a/server/service/hosts.go b/server/service/hosts.go index db94ca3f28..cf8a45e633 100644 --- a/server/service/hosts.go +++ b/server/service/hosts.go @@ -2097,3 +2097,41 @@ func (svc *Service) GetHostHealth(ctx context.Context, id uint) (*fleet.HostHeal return hh, nil } + +func (svc *Service) HostLiteByIdentifier(ctx context.Context, identifier string) (*fleet.HostLite, error) { + if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil { + return nil, err + } + + host, err := svc.ds.HostLiteByIdentifier(ctx, identifier) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "get host by identifier") + } + + if err := svc.authz.Authorize(ctx, fleet.Host{ + TeamID: host.TeamID, + }, fleet.ActionRead); err != nil { + return nil, err + } + + return host, nil +} + +func (svc *Service) HostLiteByID(ctx context.Context, id uint) (*fleet.HostLite, error) { + if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil { + return nil, err + } + + host, err := svc.ds.HostLiteByID(ctx, id) + if err != nil { + return nil, ctxerr.Wrap(ctx, err, "get host by id") + } + + if err := svc.authz.Authorize(ctx, fleet.Host{ + TeamID: host.TeamID, + }, fleet.ActionRead); err != nil { + return nil, err + } + + return host, nil +} diff --git a/server/service/integration_live_queries_test.go b/server/service/integration_live_queries_test.go index a427a22500..dc3209281f 100644 --- a/server/service/integration_live_queries_test.go +++ b/server/service/integration_live_queries_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "github.com/jmoiron/sqlx" "math" "math/rand" "net/http" @@ -23,6 +22,7 @@ import ( "github.com/fleetdm/fleet/v4/server/ptr" "github.com/fleetdm/fleet/v4/server/pubsub" "github.com/google/uuid" + "github.com/jmoiron/sqlx" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -86,8 +86,16 @@ func (s *liveQueriesTestSuite) TearDownTest() { s.lq.Mock = mock.Mock{} } +type liveQueryEndpoint int + +const ( + oldEndpoint liveQueryEndpoint = iota + oneQueryEndpoint + customQueryOneHostEndpoint +) + func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { - test := func(newEndpoint bool, savedQuery bool, hasStats bool) { + test := func(endpoint liveQueryEndpoint, savedQuery bool, hasStats bool) { t := s.T() host := s.hosts[0] @@ -114,7 +122,8 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { oneLiveQueryResp := runOneLiveQueryResponse{} liveQueryResp := runLiveQueryResponse{} - if newEndpoint { + liveQueryOnHostResp := runLiveQueryOnHostResponse{} + if endpoint == oneQueryEndpoint { liveQueryRequest := runOneLiveQueryRequest{ HostIDs: []uint{host.ID}, } @@ -122,7 +131,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { defer wg.Done() s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), liveQueryRequest, http.StatusOK, &oneLiveQueryResp) }() - } else { + } else if endpoint == oldEndpoint { liveQueryRequest := runLiveQueryRequest{ QueryIDs: []uint{q1.ID}, HostIDs: []uint{host.ID}, @@ -131,13 +140,48 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { defer wg.Done() s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp) }() + } else { // customQueryOneHostEndpoint + liveQueryRequest := runLiveQueryOnHostRequest{ + Query: query, + } + go func() { + defer wg.Done() + s.DoJSON( + "POST", fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID), liveQueryRequest, http.StatusOK, + &liveQueryOnHostResp, + ) + }() } // For loop, waiting for campaign to be created. var cid string cidChannel := make(chan string) go func() { - for { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + for range ticker.C { + if endpoint == customQueryOneHostEndpoint { + campaign := fleet.DistributedQueryCampaign{} + err := mysql.ExecAdhocSQLWithError( + s.ds, func(q sqlx.ExtContext) error { + return sqlx.GetContext( + context.Background(), q, &campaign, + `SELECT * FROM distributed_query_campaigns WHERE status = ? ORDER BY id DESC LIMIT 1`, + fleet.QueryRunning, + ) + }, + ) + if errors.Is(err, sql.ErrNoRows) { + continue + } + if err != nil { + t.Error("Error selecting from distributed_query_campaigns", err) + return + } + q1.ID = campaign.QueryID + cidChannel <- fmt.Sprint(campaign.ID) + return + } campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID) require.NoError(t, err) @@ -171,9 +215,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { hostDistributedQueryPrefix + cid: 0, hostDistributedQueryPrefix + "9999": "0", }, - Messages: map[string]string{ - hostDistributedQueryPrefix + cid: "some msg", - }, + Messages: map[string]string{}, Stats: map[string]*fleet.Stats{ hostDistributedQueryPrefix + cid: stats, }, @@ -184,19 +226,28 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { wg.Wait() var result fleet.QueryResult - if newEndpoint { + if endpoint == oneQueryEndpoint { assert.Equal(t, q1.ID, oneLiveQueryResp.QueryID) assert.Equal(t, 1, oneLiveQueryResp.TargetedHostCount) assert.Equal(t, 1, oneLiveQueryResp.RespondedHostCount) require.Len(t, oneLiveQueryResp.Results, 1) result = oneLiveQueryResp.Results[0] - } else { + } else if endpoint == oldEndpoint { require.Len(t, liveQueryResp.Results, 1) assert.Equal(t, 1, liveQueryResp.Summary.TargetedHostCount) assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount) assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID) require.Len(t, liveQueryResp.Results[0].Results, 1) result = liveQueryResp.Results[0].Results[0] + } else { // customQueryOneHostEndpoint + assert.Empty(t, liveQueryOnHostResp.Error) + assert.Equal(t, host.ID, liveQueryOnHostResp.HostID) + assert.Equal(t, fleet.StatusOnline, liveQueryOnHostResp.Status) + assert.Equal(t, query, liveQueryOnHostResp.Query) + result = fleet.QueryResult{ + HostID: liveQueryOnHostResp.HostID, + Rows: liveQueryOnHostResp.Rows, + } } assert.Equal(t, host.ID, result.HostID) require.Len(t, result.Rows, 1) @@ -207,7 +258,9 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { var activity *fleet.ActivityTypeLiveQuery activityUpdated := make(chan *fleet.ActivityTypeLiveQuery) go func() { - for { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + for range ticker.C { details := json.RawMessage{} err := mysql.ExecAdhocSQLWithError( s.ds, func(q sqlx.ExtContext) error { @@ -262,12 +315,13 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() { } } } - s.Run("not saved query (old)", func() { test(false, false, true) }) - s.Run("saved query without stats (old)", func() { test(false, true, false) }) - s.Run("saved query with stats (old)", func() { test(false, true, true) }) - s.Run("not saved query", func() { test(true, false, true) }) - s.Run("saved query without stats", func() { test(true, true, false) }) - s.Run("saved query with stats", func() { test(true, true, true) }) + s.Run("not saved query (old)", func() { test(oldEndpoint, false, true) }) + s.Run("saved query without stats (old)", func() { test(oldEndpoint, true, false) }) + s.Run("saved query with stats (old)", func() { test(oldEndpoint, true, true) }) + s.Run("not saved query", func() { test(oneQueryEndpoint, false, true) }) + s.Run("saved query without stats", func() { test(oneQueryEndpoint, true, false) }) + s.Run("saved query with stats", func() { test(oneQueryEndpoint, true, true) }) + s.Run("custom query", func() { test(customQueryOneHostEndpoint, false, false) }) } func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() { @@ -336,7 +390,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() { hostDistributedQueryPrefix + cid2: "some other msg", }, Stats: map[string]*fleet.Stats{ - hostDistributedQueryPrefix + cid1: &fleet.Stats{ + hostDistributedQueryPrefix + cid1: { UserTime: uint64(1), SystemTime: uint64(2), }, @@ -452,7 +506,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() { hostDistributedQueryPrefix + cid2: "some other msg", }, Stats: map[string]*fleet.Stats{ - hostDistributedQueryPrefix + cid1: &fleet.Stats{ + hostDistributedQueryPrefix + cid1: { UserTime: uint64(1), SystemTime: uint64(2), }, @@ -701,7 +755,6 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() { oneLiveQueryResp := runOneLiveQueryResponse{} s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", 999), oneLiveQueryRequest, http.StatusNotFound, &oneLiveQueryResp) assert.Equal(t, 0, oneLiveQueryResp.RespondedHostCount) - } func (s *liveQueriesTestSuite) TestLiveQueriesRestInvalidHost() { @@ -735,7 +788,6 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestInvalidHost() { } oneLiveQueryResp := runOneLiveQueryResponse{} s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp) - } func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() { diff --git a/server/service/live_queries.go b/server/service/live_queries.go index 6bb10308e1..811eb5122b 100644 --- a/server/service/live_queries.go +++ b/server/service/live_queries.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "os" "strconv" @@ -27,6 +28,16 @@ type runOneLiveQueryRequest struct { HostIDs []uint `json:"host_ids"` } +type runLiveQueryOnHostRequest struct { + Identifier string `url:"identifier"` + Query string `json:"query"` +} + +type runLiveQueryOnHostByIDRequest struct { + HostID uint `url:"id"` + Query string `json:"query"` +} + type summaryPayload struct { TargetedHostCount int `json:"targeted_host_count"` RespondedHostCount int `json:"responded_host_count"` @@ -51,13 +62,23 @@ type runOneLiveQueryResponse struct { func (r runOneLiveQueryResponse) error() error { return r.Err } +type runLiveQueryOnHostResponse struct { + HostID uint `json:"host_id"` + Rows []map[string]string `json:"rows"` + Query string `json:"query"` + Status fleet.HostStatus `json:"status"` + Error string `json:"error,omitempty"` +} + +func (r runLiveQueryOnHostResponse) error() error { return nil } + func runOneLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { req := request.(*runOneLiveQueryRequest) // Only allow a host to be specified once in HostIDs hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs) - campaignResults, respondedHostCount, err := runLiveQuery(ctx, svc, []uint{req.QueryID}, hostIDs) + campaignResults, respondedHostCount, err := runLiveQuery(ctx, svc, []uint{req.QueryID}, "", hostIDs) if err != nil { return nil, err } @@ -89,7 +110,7 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se // Only allow a host to be specified once in HostIDs hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs) - queryResults, respondedHostCount, err := runLiveQuery(ctx, svc, queryIDs, hostIDs) + queryResults, respondedHostCount, err := runLiveQuery(ctx, svc, queryIDs, "", hostIDs) if err != nil { return nil, err } @@ -104,7 +125,84 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se return res, nil } -func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostIDs []uint) ( +func runLiveQueryOnHostEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { + req := request.(*runLiveQueryOnHostRequest) + + if req.Query == "" { + return nil, ctxerr.Wrap(ctx, badRequest("query is required")) + } + + host, err := svc.HostLiteByIdentifier(ctx, req.Identifier) + if err != nil { + return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %s: %s", req.Identifier, err.Error()))) + } + + return runLiveQueryOnHost(svc, ctx, host, req.Query) +} + +func runLiveQueryOnHostByIDEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) { + req := request.(*runLiveQueryOnHostByIDRequest) + + if req.Query == "" { + return nil, ctxerr.Wrap(ctx, badRequest("query is required")) + } + + host, err := svc.HostLiteByID(ctx, req.HostID) + if err != nil { + return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %d: %s", req.HostID, err.Error()))) + } + + return runLiveQueryOnHost(svc, ctx, host, req.Query) +} + +func runLiveQueryOnHost(svc fleet.Service, ctx context.Context, host *fleet.HostLite, query string) (errorer, error) { + res := runLiveQueryOnHostResponse{ + HostID: host.ID, + Query: query, + } + + status := (&fleet.Host{ + DistributedInterval: host.DistributedInterval, + ConfigTLSRefresh: host.ConfigTLSRefresh, + SeenTime: host.SeenTime, + }).Status(time.Now()) + switch status { + case fleet.StatusOnline, fleet.StatusNew: + res.Status = fleet.StatusOnline + case fleet.StatusOffline, fleet.StatusMIA, fleet.StatusMissing: + res.Status = fleet.StatusOffline + return res, nil + default: + return nil, fmt.Errorf("unknown host status: %s", status) + } + + queryResults, _, err := runLiveQuery(ctx, svc, []uint{0}, query, []uint{host.ID}) + if err != nil { + return nil, err + } + + if len(queryResults) > 0 { + var err error + if queryResults[0].Err != nil { + err = queryResults[0].Err + } else if len(queryResults[0].Results) > 0 { + queryResult := queryResults[0].Results[0] + if queryResult.Error != nil { + err = errors.New(*queryResult.Error) + } + res.Rows = queryResult.Rows + res.HostID = queryResult.HostID + } else { // timeout waiting for results + err = errors.New("timeout waiting for results") + } + if err != nil { + res.Error = err.Error() + } + } + return res, nil +} + +func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, query string, hostIDs []uint) ( []fleet.QueryCampaignResult, int, error, ) { // The period used here should always be less than the request timeout for any load @@ -119,7 +217,7 @@ func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostI logging.WithExtras(ctx, "live_query_rest_period_err", err) } - queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, queryIDs, hostIDs, duration) + queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, queryIDs, query, hostIDs, duration) if err != nil { return nil, 0, err } @@ -142,7 +240,7 @@ func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostI } func (svc *Service) RunLiveQueryDeadline( - ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration, + ctx context.Context, queryIDs []uint, query string, hostIDs []uint, deadline time.Duration, ) ([]fleet.QueryCampaignResult, int, error) { if len(queryIDs) == 0 || len(hostIDs) == 0 { svc.authz.SkipAuthorization(ctx) @@ -160,11 +258,19 @@ func (svc *Service) RunLiveQueryDeadline( wg.Add(1) go func() { defer wg.Done() - campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: hostIDs}) + queryIDPtr := &queryID + queryString := "" + // 0 is a special ID that indicates we should use raw SQL query instead + if queryID == 0 { + queryIDPtr = nil + queryString = query + } + campaign, err := svc.NewDistributedQueryCampaign(ctx, queryString, queryIDPtr, fleet.HostTargets{HostIDs: hostIDs}) if err != nil { resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err} return } + queryID = campaign.QueryID readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign) if err != nil {