mirror of
https://github.com/fleetdm/fleet
synced 2026-05-21 07:58:31 +00:00
New live query API endpoint for custom query SQL. (#16810)
#16805 - [X] Changes file added for user-visible changes in `changes/` or `orbit/changes/`. See [Changes files](https://fleetdm.com/docs/contributing/committing-changes#changes-files) for more information. - [X] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) - [X] Added/updated tests - [X] Manual QA for all new/changed functionality --------- Co-authored-by: Lucas Rodriguez <lucas@fleetdm.com>
This commit is contained in:
parent
52c0d317db
commit
967eddcb37
10 changed files with 335 additions and 28 deletions
1
changes/16805-new-live-query-on-host-endpoint
Normal file
1
changes/16805-new-live-query-on-host-endpoint
Normal file
|
|
@ -0,0 +1 @@
|
|||
* Add two new API endpoints to run a live query SQL on one host: `POST /api/latest/fleet/hosts/identifier/{identifier}/query` and `POST /api/_version_/fleet/hosts/{id}/query`.
|
||||
|
|
@ -4905,3 +4905,61 @@ func (ds *Datastore) GetHostHealth(ctx context.Context, id uint) (*fleet.HostHea
|
|||
|
||||
return &hh, nil
|
||||
}
|
||||
|
||||
func (ds *Datastore) HostLiteByIdentifier(ctx context.Context, identifier string) (*fleet.HostLite, error) {
|
||||
return ds.loadHostLite(ctx, nil, &identifier)
|
||||
}
|
||||
|
||||
func (ds *Datastore) HostLiteByID(ctx context.Context, id uint) (*fleet.HostLite, error) {
|
||||
return ds.loadHostLite(ctx, &id, nil)
|
||||
}
|
||||
|
||||
func (ds *Datastore) loadHostLite(ctx context.Context, id *uint, identifier *string) (*fleet.HostLite, error) {
|
||||
if id == nil && identifier == nil {
|
||||
return nil, errors.New("must set one of id or identifier")
|
||||
}
|
||||
if id != nil && identifier != nil {
|
||||
return nil, errors.New("cannot set both id and identifier")
|
||||
}
|
||||
stmt := `
|
||||
SELECT
|
||||
h.id,
|
||||
h.team_id,
|
||||
h.osquery_host_id,
|
||||
h.node_key,
|
||||
h.hostname,
|
||||
h.uuid,
|
||||
h.hardware_serial,
|
||||
h.distributed_interval,
|
||||
h.config_tls_refresh,
|
||||
COALESCE(hst.seen_time, h.created_at) AS seen_time
|
||||
FROM hosts h
|
||||
LEFT JOIN host_seen_times hst ON (h.id = hst.host_id)
|
||||
%s
|
||||
LIMIT 1
|
||||
`
|
||||
var (
|
||||
arg interface{}
|
||||
whereClause string
|
||||
)
|
||||
if identifier != nil {
|
||||
whereClause = "WHERE ? IN (h.hostname, h.osquery_host_id, h.node_key, h.uuid, h.hardware_serial)"
|
||||
arg = identifier
|
||||
} else {
|
||||
whereClause = "WHERE id = ?"
|
||||
arg = id
|
||||
}
|
||||
host := &fleet.HostLite{}
|
||||
err := sqlx.GetContext(ctx, ds.reader(ctx), host, fmt.Sprintf(stmt, whereClause), arg)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
if identifier != nil {
|
||||
return nil, ctxerr.Wrap(ctx, notFound("Host").WithName(*identifier))
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, notFound("Host").WithID(*id))
|
||||
}
|
||||
return nil, ctxerr.Wrap(ctx, err, "get host lite")
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -258,6 +258,11 @@ type Datastore interface {
|
|||
// HostByIdentifier returns one host matching the provided identifier. Possible matches can be on
|
||||
// osquery_host_identifier, node_key, UUID, or hostname.
|
||||
HostByIdentifier(ctx context.Context, identifier string) (*Host, error)
|
||||
// HostLiteByIdentifier returns a host and a subset of its fields using an "identifier" string.
|
||||
// The identifier string will be matched against the hostname, osquery_host_id, node_key, uuid and hardware_serial columns.
|
||||
HostLiteByIdentifier(ctx context.Context, identifier string) (*HostLite, error)
|
||||
// HostLiteByIdentifier returns a host and a subset of its fields from its id.
|
||||
HostLiteByID(ctx context.Context, id uint) (*HostLite, error)
|
||||
// AddHostsToTeam adds hosts to an existing team, clearing their team settings if teamID is nil.
|
||||
AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs []uint) error
|
||||
|
||||
|
|
|
|||
|
|
@ -1201,3 +1201,17 @@ type HostMacOSProfile struct {
|
|||
// InstallDate is the date the profile was installed on the host as reported by the host's clock.
|
||||
InstallDate time.Time `json:"install_date" db:"install_date"`
|
||||
}
|
||||
|
||||
// HostLite contains a subset of Host fields.
|
||||
type HostLite struct {
|
||||
ID uint `db:"id"`
|
||||
TeamID *uint `db:"team_id"`
|
||||
Hostname string `db:"hostname"`
|
||||
OsqueryHostID string `db:"osquery_host_id"`
|
||||
NodeKey string `db:"node_key"`
|
||||
UUID string `db:"uuid"`
|
||||
HardwareSerial string `db:"hardware_serial"`
|
||||
SeenTime time.Time `db:"seen_time"`
|
||||
DistributedInterval uint `db:"distributed_interval"`
|
||||
ConfigTLSRefresh uint `db:"config_tls_refresh"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -308,7 +308,9 @@ type Service interface {
|
|||
|
||||
GetCampaignReader(ctx context.Context, campaign *DistributedQueryCampaign) (<-chan interface{}, context.CancelFunc, error)
|
||||
CompleteCampaign(ctx context.Context, campaign *DistributedQueryCampaign) error
|
||||
RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration) ([]QueryCampaignResult, int, error)
|
||||
RunLiveQueryDeadline(ctx context.Context, queryIDs []uint, query string, hostIDs []uint, deadline time.Duration) (
|
||||
[]QueryCampaignResult, int, error,
|
||||
)
|
||||
|
||||
// /////////////////////////////////////////////////////////////////////////////
|
||||
// AgentOptionsService
|
||||
|
|
@ -364,6 +366,11 @@ type Service interface {
|
|||
// device-authenticated API), or manually by the user (via the
|
||||
// user-authenticated API).
|
||||
SetCustomHostDeviceMapping(ctx context.Context, hostID uint, email string) ([]*HostDeviceMapping, error)
|
||||
// HostLiteByIdentifier returns a host and a subset of its fields using an "identifier" string.
|
||||
// The identifier string will be matched against the Hostname, OsqueryHostID, NodeKey, UUID and HardwareSerial fields.
|
||||
HostLiteByIdentifier(ctx context.Context, identifier string) (*HostLite, error)
|
||||
// HostLiteByIdentifier returns a host and a subset of its fields from its id.
|
||||
HostLiteByID(ctx context.Context, id uint) (*HostLite, error)
|
||||
|
||||
// ListDevicePolicies lists all policies for the given host, including passing / failing summaries
|
||||
ListDevicePolicies(ctx context.Context, host *Host) ([]*HostPolicy, error)
|
||||
|
|
|
|||
|
|
@ -192,6 +192,10 @@ type HostIDsByOSVersionFunc func(ctx context.Context, osVersion fleet.OSVersion,
|
|||
|
||||
type HostByIdentifierFunc func(ctx context.Context, identifier string) (*fleet.Host, error)
|
||||
|
||||
type HostLiteByIdentifierFunc func(ctx context.Context, identifier string) (*fleet.HostLite, error)
|
||||
|
||||
type HostLiteByIDFunc func(ctx context.Context, id uint) (*fleet.HostLite, error)
|
||||
|
||||
type AddHostsToTeamFunc func(ctx context.Context, teamID *uint, hostIDs []uint) error
|
||||
|
||||
type TotalAndUnseenHostsSinceFunc func(ctx context.Context, daysCount int) (total int, unseen int, err error)
|
||||
|
|
@ -1078,6 +1082,12 @@ type DataStore struct {
|
|||
HostByIdentifierFunc HostByIdentifierFunc
|
||||
HostByIdentifierFuncInvoked bool
|
||||
|
||||
HostLiteByIdentifierFunc HostLiteByIdentifierFunc
|
||||
HostLiteByIdentifierFuncInvoked bool
|
||||
|
||||
HostLiteByIDFunc HostLiteByIDFunc
|
||||
HostLiteByIDFuncInvoked bool
|
||||
|
||||
AddHostsToTeamFunc AddHostsToTeamFunc
|
||||
AddHostsToTeamFuncInvoked bool
|
||||
|
||||
|
|
@ -2626,6 +2636,20 @@ func (s *DataStore) HostByIdentifier(ctx context.Context, identifier string) (*f
|
|||
return s.HostByIdentifierFunc(ctx, identifier)
|
||||
}
|
||||
|
||||
func (s *DataStore) HostLiteByIdentifier(ctx context.Context, identifier string) (*fleet.HostLite, error) {
|
||||
s.mu.Lock()
|
||||
s.HostLiteByIdentifierFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.HostLiteByIdentifierFunc(ctx, identifier)
|
||||
}
|
||||
|
||||
func (s *DataStore) HostLiteByID(ctx context.Context, id uint) (*fleet.HostLite, error) {
|
||||
s.mu.Lock()
|
||||
s.HostLiteByIDFuncInvoked = true
|
||||
s.mu.Unlock()
|
||||
return s.HostLiteByIDFunc(ctx, id)
|
||||
}
|
||||
|
||||
func (s *DataStore) AddHostsToTeam(ctx context.Context, teamID *uint, hostIDs []uint) error {
|
||||
s.mu.Lock()
|
||||
s.AddHostsToTeamFuncInvoked = true
|
||||
|
|
|
|||
|
|
@ -378,6 +378,8 @@ func attachFleetAPIRoutes(r *mux.Router, svc fleet.Service, config config.FleetC
|
|||
ue.GET("/api/_version_/fleet/hosts/count", countHostsEndpoint, countHostsRequest{})
|
||||
ue.POST("/api/_version_/fleet/hosts/search", searchHostsEndpoint, searchHostsRequest{})
|
||||
ue.GET("/api/_version_/fleet/hosts/identifier/{identifier}", hostByIdentifierEndpoint, hostByIdentifierRequest{})
|
||||
ue.POST("/api/_version_/fleet/hosts/identifier/{identifier}/query", runLiveQueryOnHostEndpoint, runLiveQueryOnHostRequest{})
|
||||
ue.POST("/api/_version_/fleet/hosts/{id:[0-9]+}/query", runLiveQueryOnHostByIDEndpoint, runLiveQueryOnHostByIDRequest{})
|
||||
ue.DELETE("/api/_version_/fleet/hosts/{id:[0-9]+}", deleteHostEndpoint, deleteHostRequest{})
|
||||
ue.POST("/api/_version_/fleet/hosts/transfer", addHostsToTeamEndpoint, addHostsToTeamRequest{})
|
||||
ue.POST("/api/_version_/fleet/hosts/transfer/filter", addHostsToTeamByFilterEndpoint, addHostsToTeamByFilterRequest{})
|
||||
|
|
|
|||
|
|
@ -2097,3 +2097,41 @@ func (svc *Service) GetHostHealth(ctx context.Context, id uint) (*fleet.HostHeal
|
|||
|
||||
return hh, nil
|
||||
}
|
||||
|
||||
func (svc *Service) HostLiteByIdentifier(ctx context.Context, identifier string) (*fleet.HostLite, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host, err := svc.ds.HostLiteByIdentifier(ctx, identifier)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "get host by identifier")
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, fleet.Host{
|
||||
TeamID: host.TeamID,
|
||||
}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
||||
func (svc *Service) HostLiteByID(ctx context.Context, id uint) (*fleet.HostLite, error) {
|
||||
if err := svc.authz.Authorize(ctx, &fleet.Host{}, fleet.ActionList); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
host, err := svc.ds.HostLiteByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, err, "get host by id")
|
||||
}
|
||||
|
||||
if err := svc.authz.Authorize(ctx, fleet.Host{
|
||||
TeamID: host.TeamID,
|
||||
}, fleet.ActionRead); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return host, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
|
|
@ -23,6 +22,7 @@ import (
|
|||
"github.com/fleetdm/fleet/v4/server/ptr"
|
||||
"github.com/fleetdm/fleet/v4/server/pubsub"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
|
@ -86,8 +86,16 @@ func (s *liveQueriesTestSuite) TearDownTest() {
|
|||
s.lq.Mock = mock.Mock{}
|
||||
}
|
||||
|
||||
type liveQueryEndpoint int
|
||||
|
||||
const (
|
||||
oldEndpoint liveQueryEndpoint = iota
|
||||
oneQueryEndpoint
|
||||
customQueryOneHostEndpoint
|
||||
)
|
||||
|
||||
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
||||
test := func(newEndpoint bool, savedQuery bool, hasStats bool) {
|
||||
test := func(endpoint liveQueryEndpoint, savedQuery bool, hasStats bool) {
|
||||
t := s.T()
|
||||
|
||||
host := s.hosts[0]
|
||||
|
|
@ -114,7 +122,8 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
|||
|
||||
oneLiveQueryResp := runOneLiveQueryResponse{}
|
||||
liveQueryResp := runLiveQueryResponse{}
|
||||
if newEndpoint {
|
||||
liveQueryOnHostResp := runLiveQueryOnHostResponse{}
|
||||
if endpoint == oneQueryEndpoint {
|
||||
liveQueryRequest := runOneLiveQueryRequest{
|
||||
HostIDs: []uint{host.ID},
|
||||
}
|
||||
|
|
@ -122,7 +131,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
|||
defer wg.Done()
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), liveQueryRequest, http.StatusOK, &oneLiveQueryResp)
|
||||
}()
|
||||
} else {
|
||||
} else if endpoint == oldEndpoint {
|
||||
liveQueryRequest := runLiveQueryRequest{
|
||||
QueryIDs: []uint{q1.ID},
|
||||
HostIDs: []uint{host.ID},
|
||||
|
|
@ -131,13 +140,48 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
|||
defer wg.Done()
|
||||
s.DoJSON("GET", "/api/latest/fleet/queries/run", liveQueryRequest, http.StatusOK, &liveQueryResp)
|
||||
}()
|
||||
} else { // customQueryOneHostEndpoint
|
||||
liveQueryRequest := runLiveQueryOnHostRequest{
|
||||
Query: query,
|
||||
}
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.DoJSON(
|
||||
"POST", fmt.Sprintf("/api/latest/fleet/hosts/identifier/%s/query", host.UUID), liveQueryRequest, http.StatusOK,
|
||||
&liveQueryOnHostResp,
|
||||
)
|
||||
}()
|
||||
}
|
||||
|
||||
// For loop, waiting for campaign to be created.
|
||||
var cid string
|
||||
cidChannel := make(chan string)
|
||||
go func() {
|
||||
for {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
if endpoint == customQueryOneHostEndpoint {
|
||||
campaign := fleet.DistributedQueryCampaign{}
|
||||
err := mysql.ExecAdhocSQLWithError(
|
||||
s.ds, func(q sqlx.ExtContext) error {
|
||||
return sqlx.GetContext(
|
||||
context.Background(), q, &campaign,
|
||||
`SELECT * FROM distributed_query_campaigns WHERE status = ? ORDER BY id DESC LIMIT 1`,
|
||||
fleet.QueryRunning,
|
||||
)
|
||||
},
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
t.Error("Error selecting from distributed_query_campaigns", err)
|
||||
return
|
||||
}
|
||||
q1.ID = campaign.QueryID
|
||||
cidChannel <- fmt.Sprint(campaign.ID)
|
||||
return
|
||||
}
|
||||
campaigns, err := s.ds.DistributedQueryCampaignsForQuery(context.Background(), q1.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -171,9 +215,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
|||
hostDistributedQueryPrefix + cid: 0,
|
||||
hostDistributedQueryPrefix + "9999": "0",
|
||||
},
|
||||
Messages: map[string]string{
|
||||
hostDistributedQueryPrefix + cid: "some msg",
|
||||
},
|
||||
Messages: map[string]string{},
|
||||
Stats: map[string]*fleet.Stats{
|
||||
hostDistributedQueryPrefix + cid: stats,
|
||||
},
|
||||
|
|
@ -184,19 +226,28 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
|||
wg.Wait()
|
||||
|
||||
var result fleet.QueryResult
|
||||
if newEndpoint {
|
||||
if endpoint == oneQueryEndpoint {
|
||||
assert.Equal(t, q1.ID, oneLiveQueryResp.QueryID)
|
||||
assert.Equal(t, 1, oneLiveQueryResp.TargetedHostCount)
|
||||
assert.Equal(t, 1, oneLiveQueryResp.RespondedHostCount)
|
||||
require.Len(t, oneLiveQueryResp.Results, 1)
|
||||
result = oneLiveQueryResp.Results[0]
|
||||
} else {
|
||||
} else if endpoint == oldEndpoint {
|
||||
require.Len(t, liveQueryResp.Results, 1)
|
||||
assert.Equal(t, 1, liveQueryResp.Summary.TargetedHostCount)
|
||||
assert.Equal(t, 1, liveQueryResp.Summary.RespondedHostCount)
|
||||
assert.Equal(t, q1.ID, liveQueryResp.Results[0].QueryID)
|
||||
require.Len(t, liveQueryResp.Results[0].Results, 1)
|
||||
result = liveQueryResp.Results[0].Results[0]
|
||||
} else { // customQueryOneHostEndpoint
|
||||
assert.Empty(t, liveQueryOnHostResp.Error)
|
||||
assert.Equal(t, host.ID, liveQueryOnHostResp.HostID)
|
||||
assert.Equal(t, fleet.StatusOnline, liveQueryOnHostResp.Status)
|
||||
assert.Equal(t, query, liveQueryOnHostResp.Query)
|
||||
result = fleet.QueryResult{
|
||||
HostID: liveQueryOnHostResp.HostID,
|
||||
Rows: liveQueryOnHostResp.Rows,
|
||||
}
|
||||
}
|
||||
assert.Equal(t, host.ID, result.HostID)
|
||||
require.Len(t, result.Rows, 1)
|
||||
|
|
@ -207,7 +258,9 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
|||
var activity *fleet.ActivityTypeLiveQuery
|
||||
activityUpdated := make(chan *fleet.ActivityTypeLiveQuery)
|
||||
go func() {
|
||||
for {
|
||||
ticker := time.NewTicker(100 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
details := json.RawMessage{}
|
||||
err := mysql.ExecAdhocSQLWithError(
|
||||
s.ds, func(q sqlx.ExtContext) error {
|
||||
|
|
@ -262,12 +315,13 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostOneQuery() {
|
|||
}
|
||||
}
|
||||
}
|
||||
s.Run("not saved query (old)", func() { test(false, false, true) })
|
||||
s.Run("saved query without stats (old)", func() { test(false, true, false) })
|
||||
s.Run("saved query with stats (old)", func() { test(false, true, true) })
|
||||
s.Run("not saved query", func() { test(true, false, true) })
|
||||
s.Run("saved query without stats", func() { test(true, true, false) })
|
||||
s.Run("saved query with stats", func() { test(true, true, true) })
|
||||
s.Run("not saved query (old)", func() { test(oldEndpoint, false, true) })
|
||||
s.Run("saved query without stats (old)", func() { test(oldEndpoint, true, false) })
|
||||
s.Run("saved query with stats (old)", func() { test(oldEndpoint, true, true) })
|
||||
s.Run("not saved query", func() { test(oneQueryEndpoint, false, true) })
|
||||
s.Run("saved query without stats", func() { test(oneQueryEndpoint, true, false) })
|
||||
s.Run("saved query with stats", func() { test(oneQueryEndpoint, true, true) })
|
||||
s.Run("custom query", func() { test(customQueryOneHostEndpoint, false, false) })
|
||||
}
|
||||
|
||||
func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
|
||||
|
|
@ -336,7 +390,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestOneHostMultipleQuery() {
|
|||
hostDistributedQueryPrefix + cid2: "some other msg",
|
||||
},
|
||||
Stats: map[string]*fleet.Stats{
|
||||
hostDistributedQueryPrefix + cid1: &fleet.Stats{
|
||||
hostDistributedQueryPrefix + cid1: {
|
||||
UserTime: uint64(1),
|
||||
SystemTime: uint64(2),
|
||||
},
|
||||
|
|
@ -452,7 +506,7 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestMultipleHostMultipleQuery() {
|
|||
hostDistributedQueryPrefix + cid2: "some other msg",
|
||||
},
|
||||
Stats: map[string]*fleet.Stats{
|
||||
hostDistributedQueryPrefix + cid1: &fleet.Stats{
|
||||
hostDistributedQueryPrefix + cid1: {
|
||||
UserTime: uint64(1),
|
||||
SystemTime: uint64(2),
|
||||
},
|
||||
|
|
@ -701,7 +755,6 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsToCreateCampaign() {
|
|||
oneLiveQueryResp := runOneLiveQueryResponse{}
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", 999), oneLiveQueryRequest, http.StatusNotFound, &oneLiveQueryResp)
|
||||
assert.Equal(t, 0, oneLiveQueryResp.RespondedHostCount)
|
||||
|
||||
}
|
||||
|
||||
func (s *liveQueriesTestSuite) TestLiveQueriesRestInvalidHost() {
|
||||
|
|
@ -735,7 +788,6 @@ func (s *liveQueriesTestSuite) TestLiveQueriesRestInvalidHost() {
|
|||
}
|
||||
oneLiveQueryResp := runOneLiveQueryResponse{}
|
||||
s.DoJSON("POST", fmt.Sprintf("/api/latest/fleet/queries/%d/run", q1.ID), oneLiveQueryRequest, http.StatusBadRequest, &oneLiveQueryResp)
|
||||
|
||||
}
|
||||
|
||||
func (s *liveQueriesTestSuite) TestLiveQueriesRestFailsOnSomeHost() {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package service
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
|
|
@ -27,6 +28,16 @@ type runOneLiveQueryRequest struct {
|
|||
HostIDs []uint `json:"host_ids"`
|
||||
}
|
||||
|
||||
type runLiveQueryOnHostRequest struct {
|
||||
Identifier string `url:"identifier"`
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
type runLiveQueryOnHostByIDRequest struct {
|
||||
HostID uint `url:"id"`
|
||||
Query string `json:"query"`
|
||||
}
|
||||
|
||||
type summaryPayload struct {
|
||||
TargetedHostCount int `json:"targeted_host_count"`
|
||||
RespondedHostCount int `json:"responded_host_count"`
|
||||
|
|
@ -51,13 +62,23 @@ type runOneLiveQueryResponse struct {
|
|||
|
||||
func (r runOneLiveQueryResponse) error() error { return r.Err }
|
||||
|
||||
type runLiveQueryOnHostResponse struct {
|
||||
HostID uint `json:"host_id"`
|
||||
Rows []map[string]string `json:"rows"`
|
||||
Query string `json:"query"`
|
||||
Status fleet.HostStatus `json:"status"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func (r runLiveQueryOnHostResponse) error() error { return nil }
|
||||
|
||||
func runOneLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
|
||||
req := request.(*runOneLiveQueryRequest)
|
||||
|
||||
// Only allow a host to be specified once in HostIDs
|
||||
hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs)
|
||||
|
||||
campaignResults, respondedHostCount, err := runLiveQuery(ctx, svc, []uint{req.QueryID}, hostIDs)
|
||||
campaignResults, respondedHostCount, err := runLiveQuery(ctx, svc, []uint{req.QueryID}, "", hostIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -89,7 +110,7 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
|
|||
// Only allow a host to be specified once in HostIDs
|
||||
hostIDs := server.RemoveDuplicatesFromSlice(req.HostIDs)
|
||||
|
||||
queryResults, respondedHostCount, err := runLiveQuery(ctx, svc, queryIDs, hostIDs)
|
||||
queryResults, respondedHostCount, err := runLiveQuery(ctx, svc, queryIDs, "", hostIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -104,7 +125,84 @@ func runLiveQueryEndpoint(ctx context.Context, request interface{}, svc fleet.Se
|
|||
return res, nil
|
||||
}
|
||||
|
||||
func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostIDs []uint) (
|
||||
func runLiveQueryOnHostEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
|
||||
req := request.(*runLiveQueryOnHostRequest)
|
||||
|
||||
if req.Query == "" {
|
||||
return nil, ctxerr.Wrap(ctx, badRequest("query is required"))
|
||||
}
|
||||
|
||||
host, err := svc.HostLiteByIdentifier(ctx, req.Identifier)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %s: %s", req.Identifier, err.Error())))
|
||||
}
|
||||
|
||||
return runLiveQueryOnHost(svc, ctx, host, req.Query)
|
||||
}
|
||||
|
||||
func runLiveQueryOnHostByIDEndpoint(ctx context.Context, request interface{}, svc fleet.Service) (errorer, error) {
|
||||
req := request.(*runLiveQueryOnHostByIDRequest)
|
||||
|
||||
if req.Query == "" {
|
||||
return nil, ctxerr.Wrap(ctx, badRequest("query is required"))
|
||||
}
|
||||
|
||||
host, err := svc.HostLiteByID(ctx, req.HostID)
|
||||
if err != nil {
|
||||
return nil, ctxerr.Wrap(ctx, badRequest(fmt.Sprintf("host not found: %d: %s", req.HostID, err.Error())))
|
||||
}
|
||||
|
||||
return runLiveQueryOnHost(svc, ctx, host, req.Query)
|
||||
}
|
||||
|
||||
func runLiveQueryOnHost(svc fleet.Service, ctx context.Context, host *fleet.HostLite, query string) (errorer, error) {
|
||||
res := runLiveQueryOnHostResponse{
|
||||
HostID: host.ID,
|
||||
Query: query,
|
||||
}
|
||||
|
||||
status := (&fleet.Host{
|
||||
DistributedInterval: host.DistributedInterval,
|
||||
ConfigTLSRefresh: host.ConfigTLSRefresh,
|
||||
SeenTime: host.SeenTime,
|
||||
}).Status(time.Now())
|
||||
switch status {
|
||||
case fleet.StatusOnline, fleet.StatusNew:
|
||||
res.Status = fleet.StatusOnline
|
||||
case fleet.StatusOffline, fleet.StatusMIA, fleet.StatusMissing:
|
||||
res.Status = fleet.StatusOffline
|
||||
return res, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown host status: %s", status)
|
||||
}
|
||||
|
||||
queryResults, _, err := runLiveQuery(ctx, svc, []uint{0}, query, []uint{host.ID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(queryResults) > 0 {
|
||||
var err error
|
||||
if queryResults[0].Err != nil {
|
||||
err = queryResults[0].Err
|
||||
} else if len(queryResults[0].Results) > 0 {
|
||||
queryResult := queryResults[0].Results[0]
|
||||
if queryResult.Error != nil {
|
||||
err = errors.New(*queryResult.Error)
|
||||
}
|
||||
res.Rows = queryResult.Rows
|
||||
res.HostID = queryResult.HostID
|
||||
} else { // timeout waiting for results
|
||||
err = errors.New("timeout waiting for results")
|
||||
}
|
||||
if err != nil {
|
||||
res.Error = err.Error()
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, query string, hostIDs []uint) (
|
||||
[]fleet.QueryCampaignResult, int, error,
|
||||
) {
|
||||
// The period used here should always be less than the request timeout for any load
|
||||
|
|
@ -119,7 +217,7 @@ func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostI
|
|||
logging.WithExtras(ctx, "live_query_rest_period_err", err)
|
||||
}
|
||||
|
||||
queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, queryIDs, hostIDs, duration)
|
||||
queryResults, respondedHostCount, err := svc.RunLiveQueryDeadline(ctx, queryIDs, query, hostIDs, duration)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
|
@ -142,7 +240,7 @@ func runLiveQuery(ctx context.Context, svc fleet.Service, queryIDs []uint, hostI
|
|||
}
|
||||
|
||||
func (svc *Service) RunLiveQueryDeadline(
|
||||
ctx context.Context, queryIDs []uint, hostIDs []uint, deadline time.Duration,
|
||||
ctx context.Context, queryIDs []uint, query string, hostIDs []uint, deadline time.Duration,
|
||||
) ([]fleet.QueryCampaignResult, int, error) {
|
||||
if len(queryIDs) == 0 || len(hostIDs) == 0 {
|
||||
svc.authz.SkipAuthorization(ctx)
|
||||
|
|
@ -160,11 +258,19 @@ func (svc *Service) RunLiveQueryDeadline(
|
|||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
campaign, err := svc.NewDistributedQueryCampaign(ctx, "", &queryID, fleet.HostTargets{HostIDs: hostIDs})
|
||||
queryIDPtr := &queryID
|
||||
queryString := ""
|
||||
// 0 is a special ID that indicates we should use raw SQL query instead
|
||||
if queryID == 0 {
|
||||
queryIDPtr = nil
|
||||
queryString = query
|
||||
}
|
||||
campaign, err := svc.NewDistributedQueryCampaign(ctx, queryString, queryIDPtr, fleet.HostTargets{HostIDs: hostIDs})
|
||||
if err != nil {
|
||||
resultsCh <- fleet.QueryCampaignResult{QueryID: queryID, Error: ptr.String(err.Error()), Err: err}
|
||||
return
|
||||
}
|
||||
queryID = campaign.QueryID
|
||||
|
||||
readChan, cancelFunc, err := svc.GetCampaignReader(ctx, campaign)
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Reference in a new issue