New live query API endpoint for custom query SQL. (#16810)

#16805 

- [X] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- [X] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [X] Added/updated tests
- [X] Manual QA for all new/changed functionality

---------

Co-authored-by: Lucas Rodriguez <lucas@fleetdm.com>
This commit is contained in:
Victor Lyuboslavsky 2024-02-13 22:45:07 -06:00 committed by GitHub
parent 52c0d317db
commit 967eddcb37
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 335 additions and 28 deletions

View file

@ -0,0 +1 @@
* Add two new API endpoints to run a live query SQL on one host: `POST /api/latest/fleet/hosts/identifier/{identifier}/query` and `POST /api/_version_/fleet/hosts/{id}/query`.

View file

@ -4905,3 +4905,61 @@ func (ds *Datastore) GetHostHealth(ctx context.Context, id uint) (*fleet.HostHea
return &hh, nil
}
func (ds *Datastore) HostLiteByIdentifier(ctx context.Context, identifier string) (*fleet.HostLite, error) {
return ds.loadHostLite(ctx, nil, &identifier)
}
func (ds *Datastore) HostLiteByID(ctx context.Context, id uint) (*fleet.HostLite, error) {
return ds.loadHostLite(ctx, &id, nil)
}
func (ds *Datastore) loadHostLite(ctx context.Context, id *uint, identifier *string) (*fleet.HostLite, error) {
if id == nil && identifier == nil {
return nil, errors.New("must set one of id or identifier")
}
if id != nil && identifier != nil {
return nil, errors.New("cannot set both id and identifier")
}
stmt := `
SELECT
h.id,
h.team_id,
h.osquery_host_id,
h.node_key,
h.hostname,
h.uuid,
h.hardware_serial,
h.distributed_interval,
h.config_tls_refresh,
COALESCE(hst.seen_time, h.created_at) AS seen_time
FROM hosts h
LEFT JOIN host_seen_times hst ON (h.id = hst.host_id)
%s
LIMIT 1
`
var (
arg interface{}
whereClause string
)
if identifier != nil {
whereClause = "WHERE ? IN (h.hostname, h.osquery_host_id, h.node_key, h.uuid, h.hardware_serial)"
arg = identifier
} else {
whereClause = "WHERE id = ?"
arg = id
}
host := &fleet.HostLite{}
err := sqlx.GetContext(ctx, ds.reader(ctx), host, fmt.Sprintf(stmt, whereClause), arg)
if err != nil {
if err == sql.ErrNoRows {
if identifier != nil {
return nil, ctxerr.Wrap(ctx, notFound("Host").WithName(*identifier))
}
return nil, ctxerr.Wrap(ctx, notFound("Host").WithID(*id))
}
return nil, ctxerr.Wrap(ctx, err, "get host lite")
}
return host, nil
}

View file

@ -258,6 +258,11 @@ type Datastore interface {
// HostByIdentifier returns one host matching the provided identifier. Possible matches can be on
// osquery_host_identifier, node_key, UUID, or hostname.
HostByIdentifier(ctx context.Context, identifier string) (*Host, error)
// HostLiteByIdentifier returns a host and a subset of its fields using an "identifier" string.
// The identifier string will be matched against the hostname, osquery_host_id, node_key, uuid and hardware_serial columns.
HostLiteByIdentifier(ctx context.Context, identifier string) (*HostLite, error)
// HostLiteByIdentifier returns a host and a subset of its fields from its id.
HostLiteByID(ctx context.Context, id uint) (*HostLite, error)
// AddHostsToTeam adds hosts to an existing team, clearing their team settings if teamID is nil.
AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs []uint) error

View file

@ -1201,3 +1201,17 @@ type HostMacOSProfile struct {
// InstallDate is the date the profile was installed on the host as reported by the host's clock.
InstallDate time.Time `json:"install_date" db:"install_date"`
}
// HostLite contains a subset of Host fields.
type HostLite struct {
ID uint `db:"id"`
TeamID *uint `db:"team_id"`
Hostname string `db:"hostname"`
OsqueryHostID string `db:"osquery_host_id"`
NodeKey string `db:"node_key"`
UUID string `db:"uuid"`
HardwareSerial string `db:"hardware_serial"`
SeenTime time.Time `db:"seen_time"`
DistributedInterval uint `db:"distributed_interval"`
ConfigTLSRefresh uint `db:"config_tls_refresh"`
}

View file

@ -308,7 +308,9 @@ type Service interface {
GetCampaignReader(ctx context.Context, campaign *DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error)
CompleteCampaign(ctx context.Context, campaign *DistributedQueryCampaign) error
RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]QueryCampaignResult, int, error)
RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, query string, hostIDs []uint, deadline time.Duration) (
[]QueryCampaignResult, int, error,
)
// /////////////////////////////////////////////////////////////////////////////
// AgentOptionsService
@ -364,6 +366,11 @@ type Service interface {
// device-authenticated API), or manually by the user (via the
// user-authenticated API).
SetCustomHostDeviceMapping(ctx context.Context, hostID uint, email string) ([]*HostDeviceMapping, error)
// HostLiteByIdentifier returns a host and a subset of its fields using an "identifier" string.
// The identifier string will be matched against the Hostname, OsqueryHostID, NodeKey, UUID and HardwareSerial fields.
HostLiteByIdentifier(ctx context.Context, identifier string) (*HostLite, error)
// HostLiteByIdentifier returns a host and a subset of its fields from its id.
HostLiteByID(ctx context.Context, id uint) (*HostLite, error)
// ListDevicePolicies lists all policies for the given host, including passing / failing summaries
ListDevicePolicies(ctx context.Context, host *Host) ([]*HostPolicy, error)

View file

@ -192,6 +192,10 @@ type HostIDsByOSVersionFunc func(ctx context.Context, osVersion fleet.OSVersion,
type HostByIdentifierFunc func(ctx context.Context, identifier string) (*fleet.Host, error)
type HostLiteByIdentifierFunc func(ctx context.Context, identifier string) (*fleet.HostLite, error)
type HostLiteByIDFunc func(ctx context.Context, id uint) (*fleet.HostLite, error)
type AddHostsToTeamFunc func(ctx context.Context, teamID *uint, hostIDs []uint) error
type TotalAndUnseenHostsSinceFunc func(ctx context.Context, daysCount int) (total int, unseen int, err error)
@ -1078,6 +1082,12 @@ type DataStore struct {
HostByIdentifierFunc HostByIdentifierFunc
HostByIdentifierFuncInvoked bool
HostLiteByIdentifierFunc HostLiteByIdentifierFunc
HostLiteByIdentifierFuncInvoked bool
HostLiteByIDFunc HostLiteByIDFunc
HostLiteByIDFuncInvoked bool
AddHostsToTeamFunc AddHostsToTeamFunc
AddHostsToTeamFuncInvoked bool
@ -2626,6 +2636,20 @@ func (s *DataStore) HostByIdentifier(ctx context.Context, identifier string) (*f
return s.HostByIdentifierFunc(ctx, identifier)
}
func (s *DataStore) HostLiteByIdentifier(ctx context.Context, identifier string) (*fleet.HostLite, error) {
s.mu.Lock()
s.HostLiteByIdentifierFuncInvoked = true
s.mu.Unlock()
return s.HostLiteByIdentifierFunc(ctx, identifier)
}
func (s *DataStore) HostLiteByID(ctx context.Context, id uint) (*fleet.HostLite, error) {
s.mu.Lock()
s.HostLiteByIDFuncInvoked = true
s.mu.Unlock()
return s.HostLiteByIDFunc(ctx, id)
}
func (s *DataStore) AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs []uint) error {
s.mu.Lock()
s.AddHostsToTeamFuncInvoked = true

View file

@ -378,6 +378,8 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
ue.GET("/api/_version_/fleet/hosts/count", countHostsEndpoint, countHostsRequest{})
ue.POST("/api/_version_/fleet/hosts/search", searchHostsEndpoint, searchHostsRequest{})
ue.GET("/api/_version_/fleet/hosts/identifier/{identifier}", hostByIdentifierEndpoint, hostByIdentifierRequest{})
ue.POST("/api/_version_/fleet/hosts/identifier/{identifier}/query", runLiveQueryOnHostEndpoint, runLiveQueryOnHostRequest{})
ue.POST("/api/_version_/fleet/hosts/{id:[0-9]+}/query", runLiveQueryOnHostByIDEndpoint, runLiveQueryOnHostByIDRequest{})
ue.DELETE("/api/_version_/fleet/hosts/{id:[0-9]+}", deleteHostEndpoint, deleteHostRequest{})
ue.POST("/api/_version_/fleet/hosts/transfer", addHostsToTeamEndpoint, addHostsToTeamRequest{})
ue.POST("/api/_version_/fleet/hosts/transfer/filter", addHostsToTeamByFilterEndpoint, addHostsToTeamByFilterRequest{})

View file

@ -2097,3 +2097,41 @@ func (svc *Service) GetHostHealth(ctx context.Context, id uint) (*fleet.HostHeal
return hh, nil
}
func (svc *Service) HostLiteByIdentifier(ctx context.Context, identifier string) (*fleet.HostLite, error) {
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
return nil, err
}
host, err := svc.ds.HostLiteByIdentifier(ctx, identifier)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get host by identifier")
}
if err := svc.authz.Authorize(ctx, fleet.Host{
TeamID: host.TeamID,
}, fleet.ActionRead); err != nil {
return nil, err
}
return host, nil
}
func (svc *Service) HostLiteByID(ctx context.Context, id uint) (*fleet.HostLite, error) {
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
return nil, err
}
host, err := svc.ds.HostLiteByID(ctx, id)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get host by id")
}
if err := svc.authz.Authorize(ctx, fleet.Host{
TeamID: host.TeamID,
}, fleet.ActionRead); err != nil {
return nil, err
}
return host, nil
}

View file

@ -6,7 +6,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/jmoiron/sqlx"
"math"
"math/rand"
"net/http"
@ -23,6 +22,7 @@ import (
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/fleetdm/fleet/v4/server/pubsub"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
@ -86,8 +86,16 @@ func (s *liveQueriesTestSuite) TearDownTest() {
s.lq.Mock = mock.Mock{}
}
type liveQueryEndpoint int
const (
oldEndpoint liveQueryEndpoint = iota
oneQueryEndpoint
customQueryOneHostEndpoint
)
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
test := func(newEndpoint bool, savedQuery bool, hasStats bool) {
test := func(endpoint liveQueryEndpoint, savedQuery bool, hasStats bool) {
t := s.T()
host := s.hosts[0]
@ -114,7 +122,8 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
oneLiveQueryResp := runOneLiveQueryResponse{}
liveQueryResp := runLiveQueryResponse{}
if newEndpoint {
liveQueryOnHostResp := runLiveQueryOnHostResponse{}
if endpoint == oneQueryEndpoint {
liveQueryRequest := runOneLiveQueryRequest{
HostIDs: []uint{host.ID},
}
@ -122,7 +131,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
defer wg.Done()
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), liveQueryRequest, http.StatusOK, &oneLiveQueryResp)
}()
} else {
} else if endpoint == oldEndpoint {
liveQueryRequest := runLiveQueryRequest{
QueryIDs: []uint{q1.ID},
HostIDs: []uint{host.ID},
@ -131,13 +140,48 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
defer wg.Done()
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
}()
} else { // customQueryOneHostEndpoint
liveQueryRequest := runLiveQueryOnHostRequest{
Query: query,
}
go func() {
defer wg.Done()
s.DoJSON(
"POST", fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID), liveQueryRequest, http.StatusOK,
&liveQueryOnHostResp,
)
}()
}
// For loop, waiting for campaign to be created.
var cid string
cidChannel := make(chan string)
go func() {
for {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for range ticker.C {
if endpoint == customQueryOneHostEndpoint {
campaign := fleet.DistributedQueryCampaign{}
err := mysql.ExecAdhocSQLWithError(
s.ds, func(q sqlx.ExtContext) error {
return sqlx.GetContext(
context.Background(), q, &campaign,
`SELECT * FROM distributed_query_campaigns WHERE status = ? ORDER BY id DESC LIMIT 1`,
fleet.QueryRunning,
)
},
)
if errors.Is(err, sql.ErrNoRows) {
continue
}
if err != nil {
t.Error("Error selecting from distributed_query_campaigns", err)
return
}
q1.ID = campaign.QueryID
cidChannel <- fmt.Sprint(campaign.ID)
return
}
campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID)
require.NoError(t, err)
@ -171,9 +215,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
hostDistributedQueryPrefix + cid: 0,
hostDistributedQueryPrefix + "9999": "0",
},
Messages: map[string]string{
hostDistributedQueryPrefix + cid: "some msg",
},
Messages: map[string]string{},
Stats: map[string]*fleet.Stats{
hostDistributedQueryPrefix + cid: stats,
},
@ -184,19 +226,28 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
wg.Wait()
var result fleet.QueryResult
if newEndpoint {
if endpoint == oneQueryEndpoint {
assert.Equal(t, q1.ID, oneLiveQueryResp.QueryID)
assert.Equal(t, 1, oneLiveQueryResp.TargetedHostCount)
assert.Equal(t, 1, oneLiveQueryResp.RespondedHostCount)
require.Len(t, oneLiveQueryResp.Results, 1)
result = oneLiveQueryResp.Results[0]
} else {
} else if endpoint == oldEndpoint {
require.Len(t, liveQueryResp.Results, 1)
assert.Equal(t, 1, liveQueryResp.Summary.TargetedHostCount)
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
require.Len(t, liveQueryResp.Results[0].Results, 1)
result = liveQueryResp.Results[0].Results[0]
} else { // customQueryOneHostEndpoint
assert.Empty(t, liveQueryOnHostResp.Error)
assert.Equal(t, host.ID, liveQueryOnHostResp.HostID)
assert.Equal(t, fleet.StatusOnline, liveQueryOnHostResp.Status)
assert.Equal(t, query, liveQueryOnHostResp.Query)
result = fleet.QueryResult{
HostID: liveQueryOnHostResp.HostID,
Rows: liveQueryOnHostResp.Rows,
}
}
assert.Equal(t, host.ID, result.HostID)
require.Len(t, result.Rows, 1)
@ -207,7 +258,9 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
var activity *fleet.ActivityTypeLiveQuery
activityUpdated := make(chan *fleet.ActivityTypeLiveQuery)
go func() {
for {
ticker := time.NewTicker(100 * time.Millisecond)
defer ticker.Stop()
for range ticker.C {
details := json.RawMessage{}
err := mysql.ExecAdhocSQLWithError(
s.ds, func(q sqlx.ExtContext) error {
@ -262,12 +315,13 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
}
}
}
s.Run("not saved query (old)", func() { test(false, false, true) })
s.Run("saved query without stats (old)", func() { test(false, true, false) })
s.Run("saved query with stats (old)", func() { test(false, true, true) })
s.Run("not saved query", func() { test(true, false, true) })
s.Run("saved query without stats", func() { test(true, true, false) })
s.Run("saved query with stats", func() { test(true, true, true) })
s.Run("not saved query (old)", func() { test(oldEndpoint, false, true) })
s.Run("saved query without stats (old)", func() { test(oldEndpoint, true, false) })
s.Run("saved query with stats (old)", func() { test(oldEndpoint, true, true) })
s.Run("not saved query", func() { test(oneQueryEndpoint, false, true) })
s.Run("saved query without stats", func() { test(oneQueryEndpoint, true, false) })
s.Run("saved query with stats", func() { test(oneQueryEndpoint, true, true) })
s.Run("custom query", func() { test(customQueryOneHostEndpoint, false, false) })
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
@ -336,7 +390,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
hostDistributedQueryPrefix + cid2: "some other msg",
},
Stats: map[string]*fleet.Stats{
hostDistributedQueryPrefix + cid1: &fleet.Stats{
hostDistributedQueryPrefix + cid1: {
UserTime: uint64(1),
SystemTime: uint64(2),
},
@ -452,7 +506,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
hostDistributedQueryPrefix + cid2: "some other msg",
},
Stats: map[string]*fleet.Stats{
hostDistributedQueryPrefix + cid1: &fleet.Stats{
hostDistributedQueryPrefix + cid1: {
UserTime: uint64(1),
SystemTime: uint64(2),
},
@ -701,7 +755,6 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() {
oneLiveQueryResp := runOneLiveQueryResponse{}
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", 999), oneLiveQueryRequest, http.StatusNotFound, &oneLiveQueryResp)
assert.Equal(t, 0, oneLiveQueryResp.RespondedHostCount)
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestInvalidHost() {
@ -735,7 +788,6 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestInvalidHost() {
}
oneLiveQueryResp := runOneLiveQueryResponse{}
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp)
}
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() {

View file

@ -2,6 +2,7 @@ package service
import (
"context"
"errors"
"fmt"
"os"
"strconv"
@ -27,6 +28,16 @@ type runOneLiveQueryRequest struct {
HostIDs []uint `json:"host_ids"`
}
type runLiveQueryOnHostRequest struct {
Identifier string `url:"identifier"`
Query string `json:"query"`
}
type runLiveQueryOnHostByIDRequest struct {
HostID uint `url:"id"`
Query string `json:"query"`
}
type summaryPayload struct {
TargetedHostCount int `json:"targeted_host_count"`
RespondedHostCount int `json:"responded_host_count"`
@ -51,13 +62,23 @@ type runOneLiveQueryResponse struct {
func (r runOneLiveQueryResponse) error() error { return r.Err }
type runLiveQueryOnHostResponse struct {
HostID uint `json:"host_id"`
Rows []map[string]string `json:"rows"`
Query string `json:"query"`
Status fleet.HostStatus `json:"status"`
Error string `json:"error,omitempty"`
}
func (r runLiveQueryOnHostResponse) error() error { return nil }
func runOneLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*runOneLiveQueryRequest)
// Only allow a host to be specified once in HostIDs
hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs)
campaignResults, respondedHostCount, err := runLiveQuery(ctx, svc, []uint{req.QueryID}, hostIDs)
campaignResults, respondedHostCount, err := runLiveQuery(ctx, svc, []uint{req.QueryID}, "", hostIDs)
if err != nil {
return nil, err
}
@ -89,7 +110,7 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
// Only allow a host to be specified once in HostIDs
hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs)
queryResults, respondedHostCount, err := runLiveQuery(ctx, svc, queryIDs, hostIDs)
queryResults, respondedHostCount, err := runLiveQuery(ctx, svc, queryIDs, "", hostIDs)
if err != nil {
return nil, err
}
@ -104,7 +125,84 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
return res, nil
}
func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostIDs []uint) (
func runLiveQueryOnHostEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*runLiveQueryOnHostRequest)
if req.Query == "" {
return nil, ctxerr.Wrap(ctx, badRequest("query is required"))
}
host, err := svc.HostLiteByIdentifier(ctx, req.Identifier)
if err != nil {
return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %s: %s", req.Identifier, err.Error())))
}
return runLiveQueryOnHost(svc, ctx, host, req.Query)
}
func runLiveQueryOnHostByIDEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
req := request.(*runLiveQueryOnHostByIDRequest)
if req.Query == "" {
return nil, ctxerr.Wrap(ctx, badRequest("query is required"))
}
host, err := svc.HostLiteByID(ctx, req.HostID)
if err != nil {
return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %d: %s", req.HostID, err.Error())))
}
return runLiveQueryOnHost(svc, ctx, host, req.Query)
}
func runLiveQueryOnHost(svc fleet.Service, ctx context.Context, host *fleet.HostLite, query string) (errorer, error) {
res := runLiveQueryOnHostResponse{
HostID: host.ID,
Query: query,
}
status := (&fleet.Host{
DistributedInterval: host.DistributedInterval,
ConfigTLSRefresh: host.ConfigTLSRefresh,
SeenTime: host.SeenTime,
}).Status(time.Now())
switch status {
case fleet.StatusOnline, fleet.StatusNew:
res.Status = fleet.StatusOnline
case fleet.StatusOffline, fleet.StatusMIA, fleet.StatusMissing:
res.Status = fleet.StatusOffline
return res, nil
default:
return nil, fmt.Errorf("unknown host status: %s", status)
}
queryResults, _, err := runLiveQuery(ctx, svc, []uint{0}, query, []uint{host.ID})
if err != nil {
return nil, err
}
if len(queryResults) > 0 {
var err error
if queryResults[0].Err != nil {
err = queryResults[0].Err
} else if len(queryResults[0].Results) > 0 {
queryResult := queryResults[0].Results[0]
if queryResult.Error != nil {
err = errors.New(*queryResult.Error)
}
res.Rows = queryResult.Rows
res.HostID = queryResult.HostID
} else { // timeout waiting for results
err = errors.New("timeout waiting for results")
}
if err != nil {
res.Error = err.Error()
}
}
return res, nil
}
func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, query string, hostIDs []uint) (
[]fleet.QueryCampaignResult, int, error,
) {
// The period used here should always be less than the request timeout for any load
@ -119,7 +217,7 @@ func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostI
logging.WithExtras(ctx, "live_query_rest_period_err", err)
}
queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, queryIDs, hostIDs, duration)
queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, queryIDs, query, hostIDs, duration)
if err != nil {
return nil, 0, err
}
@ -142,7 +240,7 @@ func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostI
}
func (svc *Service) RunLiveQueryDeadline(
ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration,
ctx context.Context, queryIDs []uint, query string, hostIDs []uint, deadline time.Duration,
) ([]fleet.QueryCampaignResult, int, error) {
if len(queryIDs) == 0 || len(hostIDs) == 0 {
svc.authz.SkipAuthorization(ctx)
@ -160,11 +258,19 @@ func (svc *Service) RunLiveQueryDeadline(
wg.Add(1)
go func() {
defer wg.Done()
campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: hostIDs})
queryIDPtr := &queryID
queryString := ""
// 0 is a special ID that indicates we should use raw SQL query instead
if queryID == 0 {
queryIDPtr = nil
queryString = query
}
campaign, err := svc.NewDistributedQueryCampaign(ctx, queryString, queryIDPtr, fleet.HostTargets{HostIDs: hostIDs})
if err != nil {
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err}
return
}
queryID = campaign.QueryID
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
if err != nil {