mirror of
https://github.com/fleetdm/fleet
synced 2026-05-06 14:58:33 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #32220 # Details This PR fixes an issue where hosts whose running scripts were canceled by Orbit (e.g. due to timing out) were reported as being still "pending" on the batch script details view. This was due to our only counting runs as errored if the error code was > 0, and ignoring negative error codes (which is what Orbit uses for this case). # Checklist for submitter If some of the following don't apply, delete the relevant line. - [X] Input data is properly validated, `SELECT *` is avoided, SQL injection is prevented (using placeholders for values in statements) ## Testing - [X] Added/updated automated tests Changed a couple of places where we were using `1` for an error code to `-1` - [X] Where appropriate, [automated tests simulate multiple hosts and test for host isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing) (updates to one hosts's records do not affect another) - [X] QA'd all new/changed functionality manually For unreleased bug fixes in a release candidate, one of: - [X] Confirmed that the fix is not expected to adversely impact load test results - [X] Alerted the release DRI if additional load testing is needed
2537 lines
79 KiB
Go
2537 lines
79 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"crypto/md5" //nolint:gosec
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
constants "github.com/fleetdm/fleet/v4/pkg/scripts"
|
|
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/go-kit/log/level"
|
|
"github.com/google/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
func (ds *Datastore) NewHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
|
|
var res *fleet.HostScriptResult
|
|
return res, ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var err error
|
|
if request.ScriptContentID == 0 {
|
|
// then we are doing a sync execution, so create the contents first
|
|
scRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
id, _ := scRes.LastInsertId()
|
|
request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115
|
|
}
|
|
res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, false)
|
|
return err
|
|
})
|
|
}
|
|
|
|
func (ds *Datastore) newHostScriptExecutionRequest(ctx context.Context, tx sqlx.ExtContext, request *fleet.HostScriptRequestPayload, isInternal bool) (*fleet.HostScriptResult, error) {
|
|
const (
|
|
getStmt = `
|
|
SELECT
|
|
ua.id, ua.host_id, ua.execution_id, ua.created_at, sua.script_id, sua.policy_id, ua.user_id,
|
|
payload->'$.sync_request' AS sync_request,
|
|
sc.contents as script_contents, sua.setup_experience_script_id
|
|
FROM
|
|
upcoming_activities ua
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
INNER JOIN script_contents sc
|
|
ON sua.script_content_id = sc.id
|
|
WHERE
|
|
ua.id = ?
|
|
`
|
|
)
|
|
|
|
_, activityID, err := ds.insertNewHostScriptExecution(ctx, tx, request, isInternal)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "inserting new script execution request")
|
|
}
|
|
|
|
var script fleet.HostScriptResult
|
|
err = sqlx.GetContext(ctx, tx, &script, getStmt, activityID)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "getting the created host script activity to return")
|
|
}
|
|
|
|
return &script, nil
|
|
}
|
|
|
|
func (ds *Datastore) insertNewHostScriptExecution(ctx context.Context, tx sqlx.ExtContext, request *fleet.HostScriptRequestPayload, isInternal bool) (string, int64, error) {
|
|
const (
|
|
insUAStmt = `
|
|
INSERT INTO upcoming_activities
|
|
(host_id, priority, user_id, fleet_initiated, activity_type, execution_id, payload)
|
|
VALUES
|
|
(?, ?, ?, ?, 'script', ?,
|
|
JSON_OBJECT(
|
|
'sync_request', ?,
|
|
'is_internal', ?,
|
|
'user', (SELECT JSON_OBJECT('name', name, 'email', email, 'gravatar_url', gravatar_url) FROM users WHERE id = ?)
|
|
)
|
|
)`
|
|
|
|
insSUAStmt = `
|
|
INSERT INTO script_upcoming_activities
|
|
(upcoming_activity_id, script_id, script_content_id, policy_id, setup_experience_script_id)
|
|
VALUES
|
|
(?, ?, ?, ?, ?)
|
|
`
|
|
)
|
|
|
|
execID := uuid.New().String()
|
|
result, err := tx.ExecContext(ctx, insUAStmt,
|
|
request.HostID,
|
|
request.Priority(),
|
|
request.UserID,
|
|
request.PolicyID != nil, // fleet-initiated if request is via a policy failure
|
|
execID,
|
|
request.SyncRequest,
|
|
isInternal,
|
|
request.UserID,
|
|
)
|
|
if err != nil {
|
|
return "", 0, ctxerr.Wrap(ctx, err, "new script upcoming activity")
|
|
}
|
|
|
|
activityID, _ := result.LastInsertId()
|
|
_, err = tx.ExecContext(ctx, insSUAStmt,
|
|
activityID,
|
|
request.ScriptID,
|
|
request.ScriptContentID,
|
|
request.PolicyID,
|
|
request.SetupExperienceScriptID,
|
|
)
|
|
if err != nil {
|
|
return "", 0, ctxerr.Wrap(ctx, err, "new join script upcoming activity")
|
|
}
|
|
|
|
if _, err := ds.activateNextUpcomingActivity(ctx, tx, request.HostID, ""); err != nil {
|
|
return "", 0, ctxerr.Wrap(ctx, err, "activate next activity")
|
|
}
|
|
|
|
return execID, activityID, nil
|
|
}
|
|
|
|
// truncateScriptResult keeps at most the last 10,000 runes of output,
// dropping anything before them.
func truncateScriptResult(output string) string {
	const maxOutputRuneLen = 10000

	// No rune is longer than utf8.UTFMax bytes, so bytes further than
	// UTFMax*maxOutputRuneLen from the end can never survive; drop them before
	// the costlier rune conversion. This byte cut may split a rune, but the
	// remaining valid runes always number at least maxOutputRuneLen, so the
	// split bytes are discarded by the rune-based cut below.
	if maxBytes := utf8.UTFMax * maxOutputRuneLen; len(output) > maxBytes {
		output = output[len(output)-maxBytes:]
	}

	if utf8.RuneCountInString(output) <= maxOutputRuneLen {
		return output
	}
	runes := []rune(output)
	return string(runes[len(runes)-maxOutputRuneLen:])
}
|
|
|
|
func (ds *Datastore) SetHostScriptExecutionResult(ctx context.Context, result *fleet.HostScriptResultPayload) (*fleet.HostScriptResult,
|
|
string, error,
|
|
) {
|
|
const resultExistsStmt = `
|
|
SELECT
|
|
1
|
|
FROM
|
|
host_script_results
|
|
WHERE
|
|
host_id = ? AND
|
|
execution_id = ? AND
|
|
exit_code IS NOT NULL
|
|
`
|
|
|
|
const updStmt = `
|
|
UPDATE host_script_results SET
|
|
output = ?,
|
|
runtime = ?,
|
|
exit_code = ?,
|
|
timeout = ?
|
|
WHERE
|
|
host_id = ? AND
|
|
execution_id = ?`
|
|
|
|
const hostMDMActionsStmt = `
|
|
SELECT 'uninstall' AS action
|
|
FROM
|
|
host_software_installs
|
|
WHERE
|
|
execution_id = :execution_id AND host_id = :host_id
|
|
UNION -- host_mdm_actions query (and thus row in union) must be last to avoid #25144
|
|
SELECT
|
|
CASE
|
|
WHEN lock_ref = :execution_id THEN 'lock_ref'
|
|
WHEN unlock_ref = :execution_id THEN 'unlock_ref'
|
|
WHEN wipe_ref = :execution_id THEN 'wipe_ref'
|
|
ELSE ''
|
|
END AS action
|
|
FROM
|
|
host_mdm_actions
|
|
WHERE
|
|
host_id = :host_id
|
|
`
|
|
|
|
output := truncateScriptResult(result.Output)
|
|
|
|
var hsr *fleet.HostScriptResult
|
|
var action string
|
|
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var resultExists bool
|
|
err := sqlx.GetContext(ctx, tx, &resultExists, resultExistsStmt, result.HostID, result.ExecutionID)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return ctxerr.Wrap(ctx, err, "check if host script result exists")
|
|
}
|
|
if resultExists {
|
|
level.Debug(ds.logger).Log("msg", "duplicate script execution result sent, will be ignored (original result is preserved)",
|
|
"host_id", result.HostID,
|
|
"execution_id", result.ExecutionID,
|
|
)
|
|
|
|
// still do the activate next activity to ensure progress as there was
|
|
// an unexpected flow if we get here.
|
|
if _, err := ds.activateNextUpcomingActivity(ctx, tx, result.HostID, result.ExecutionID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "activate next activity")
|
|
}
|
|
|
|
// succeed but leave hsr nil
|
|
return nil
|
|
}
|
|
|
|
res, err := tx.ExecContext(ctx, updStmt,
|
|
output,
|
|
result.Runtime,
|
|
// Windows error codes are signed 32-bit integers, but are
|
|
// returned as unsigned integers by the windows API. The
|
|
// software that receives them is responsible for casting
|
|
// it to a 32-bit signed integer.
|
|
// See /orbit/pkg/scripts/exec_windows.go
|
|
int32(result.ExitCode), //nolint:gosec // dismiss G115
|
|
result.Timeout,
|
|
result.HostID,
|
|
result.ExecutionID,
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "update host script result")
|
|
}
|
|
|
|
if n, _ := res.RowsAffected(); n > 0 {
|
|
// it did update, so return the updated result
|
|
hsr, err = ds.getHostScriptExecutionResultDB(ctx, tx, result.ExecutionID, scriptExecutionSearchOpts{IncludeCanceled: true})
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "load updated host script result")
|
|
}
|
|
|
|
// look up if that script was a lock/unlock/wipe/uninstall script for that host,
|
|
// and if so update the host_mdm_actions table accordingly.
|
|
namedArgs := map[string]any{
|
|
"host_id": result.HostID,
|
|
"execution_id": result.ExecutionID,
|
|
}
|
|
stmt, args, err := sqlx.Named(hostMDMActionsStmt, namedArgs)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "build named query for host mdm actions")
|
|
}
|
|
err = sqlx.GetContext(ctx, tx, &action, stmt, args...)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) { // ignore ErrNoRows, refCol will be empty
|
|
return ctxerr.Wrap(ctx, err, "lookup host script corresponding mdm action")
|
|
}
|
|
|
|
switch action {
|
|
case "":
|
|
// do nothing
|
|
case "uninstall":
|
|
err = ds.updateUninstallStatusFromResult(ctx, tx, result.HostID, result.ExecutionID, result.ExitCode)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "update host uninstall action based on script result")
|
|
}
|
|
default: // lock/unlock/wipe
|
|
err = updateHostLockWipeStatusFromResult(ctx, tx, result.HostID, action, result.ExitCode == 0)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "update host mdm action based on script result")
|
|
}
|
|
}
|
|
}
|
|
|
|
if _, err := ds.activateNextUpcomingActivity(ctx, tx, result.HostID, result.ExecutionID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "activate next activity")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
return hsr, action, nil
|
|
}
|
|
|
|
func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID uint, onlyShowInternal bool) ([]*fleet.HostScriptResult, error) {
|
|
return ds.listUpcomingHostScriptExecutions(ctx, hostID, onlyShowInternal, false)
|
|
}
|
|
|
|
func (ds *Datastore) ListReadyToExecuteScriptsForHost(ctx context.Context, hostID uint, onlyShowInternal bool) ([]*fleet.HostScriptResult, error) {
|
|
return ds.listUpcomingHostScriptExecutions(ctx, hostID, onlyShowInternal, true)
|
|
}
|
|
|
|
func (ds *Datastore) listUpcomingHostScriptExecutions(ctx context.Context, hostID uint, onlyShowInternal, onlyReadyToExecute bool) ([]*fleet.HostScriptResult, error) {
|
|
extraWhere := ""
|
|
if onlyShowInternal {
|
|
// software_uninstalls are implicitly internal
|
|
extraWhere = " AND COALESCE(ua.payload->'$.is_internal', 1) = 1"
|
|
}
|
|
if onlyReadyToExecute {
|
|
extraWhere += " AND ua.activated_at IS NOT NULL"
|
|
}
|
|
// this selects software uninstalls too as they run as scripts
|
|
listStmt := fmt.Sprintf(`
|
|
SELECT
|
|
id,
|
|
host_id,
|
|
execution_id,
|
|
script_id,
|
|
created_at
|
|
FROM (
|
|
SELECT
|
|
ua.id,
|
|
ua.host_id,
|
|
ua.execution_id,
|
|
sua.script_id,
|
|
ua.priority,
|
|
ua.created_at,
|
|
IF(ua.activated_at IS NULL, 0, 1) AS topmost
|
|
FROM
|
|
upcoming_activities ua
|
|
-- left join because software_uninstall has no script join
|
|
LEFT JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
WHERE
|
|
ua.host_id = ? AND
|
|
ua.activity_type IN ('script', 'software_uninstall')
|
|
%s
|
|
ORDER BY topmost DESC, priority DESC, created_at ASC) t`, extraWhere)
|
|
|
|
var results []*fleet.HostScriptResult
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, listStmt, hostID); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "list pending host script executions")
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (ds *Datastore) IsExecutionPendingForHost(ctx context.Context, hostID uint, scriptID uint) (bool, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
1
|
|
FROM
|
|
upcoming_activities ua
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
WHERE
|
|
ua.host_id = ? AND
|
|
ua.activity_type = 'script' AND
|
|
sua.script_id = ?
|
|
`
|
|
|
|
var results []*uint
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, getStmt, hostID, scriptID); err != nil {
|
|
return false, ctxerr.Wrap(ctx, err, "is execution pending for host")
|
|
}
|
|
return len(results) > 0, nil
|
|
}
|
|
|
|
// scriptExecutionSearchOpts adjusts how getHostScriptExecutionResultDB looks
// up a script execution.
type scriptExecutionSearchOpts struct {
	// IncludeCanceled also matches host_script_results rows that were
	// canceled; otherwise only rows with canceled = 0 are considered.
	IncludeCanceled bool
	// UninstallHostID, when non-zero, restricts the host_script_results lookup
	// to software uninstall executions (host_software_installs.uninstall) of
	// this host.
	UninstallHostID uint
}
|
|
|
|
func (ds *Datastore) GetHostScriptExecutionResult(ctx context.Context, execID string) (*fleet.HostScriptResult, error) {
|
|
return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID, scriptExecutionSearchOpts{})
|
|
}
|
|
|
|
func (ds *Datastore) GetSelfServiceUninstallScriptExecutionResult(ctx context.Context, execID string, hostID uint) (*fleet.HostScriptResult, error) {
|
|
return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID, scriptExecutionSearchOpts{UninstallHostID: hostID})
|
|
}
|
|
|
|
func (ds *Datastore) getHostScriptExecutionResultDB(ctx context.Context, q sqlx.QueryerContext, execID string, opts scriptExecutionSearchOpts) (*fleet.HostScriptResult, error) {
|
|
var activeParams []any
|
|
|
|
canceledCondition := ""
|
|
if !opts.IncludeCanceled {
|
|
canceledCondition = " AND hsr.canceled = 0"
|
|
}
|
|
|
|
uninstallCondition := ""
|
|
if opts.UninstallHostID > 0 {
|
|
uninstallCondition = `JOIN host_software_installs hsi ON hsi.execution_id = hsr.execution_id
|
|
AND hsi.uninstall = TRUE AND hsr.host_id = ?`
|
|
activeParams = append(activeParams, opts.UninstallHostID)
|
|
}
|
|
|
|
activeParams = append(activeParams, execID)
|
|
|
|
getActiveStmt := fmt.Sprintf(`
|
|
SELECT
|
|
hsr.id,
|
|
hsr.host_id,
|
|
hsr.execution_id,
|
|
sc.contents as script_contents,
|
|
hsr.script_id,
|
|
hsr.policy_id,
|
|
hsr.output,
|
|
hsr.runtime,
|
|
hsr.exit_code,
|
|
hsr.timeout,
|
|
hsr.created_at,
|
|
hsr.user_id,
|
|
hsr.sync_request,
|
|
hsr.host_deleted_at,
|
|
hsr.setup_experience_script_id,
|
|
hsr.canceled,
|
|
bahr.batch_execution_id
|
|
FROM
|
|
host_script_results hsr
|
|
LEFT JOIN
|
|
batch_activity_host_results bahr ON hsr.execution_id = bahr.host_execution_id
|
|
JOIN
|
|
script_contents sc
|
|
%s
|
|
WHERE
|
|
hsr.execution_id = ? AND
|
|
hsr.script_content_id = sc.id
|
|
%s
|
|
`, uninstallCondition, canceledCondition)
|
|
|
|
// We don't include upcoming uninstall script executions in results (different activity type, and they're blank anyway)
|
|
const getUpcomingStmt = `
|
|
SELECT
|
|
0 as id,
|
|
ua.host_id,
|
|
ua.execution_id,
|
|
sc.contents as script_contents,
|
|
sua.script_id,
|
|
sua.policy_id,
|
|
'' as output,
|
|
0 as runtime,
|
|
NULL as exit_code,
|
|
NULL as timeout,
|
|
ua.created_at,
|
|
ua.user_id,
|
|
COALESCE(ua.payload->'$.sync_request', 0) as sync_request,
|
|
NULL as host_deleted_at,
|
|
sua.setup_experience_script_id,
|
|
0 as canceled
|
|
FROM
|
|
upcoming_activities ua
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
INNER JOIN
|
|
script_contents sc
|
|
ON sua.script_content_id = sc.id
|
|
WHERE
|
|
ua.execution_id = ? AND
|
|
ua.activity_type = 'script'
|
|
`
|
|
|
|
var result fleet.HostScriptResult
|
|
if err := sqlx.GetContext(ctx, q, &result, getActiveStmt, activeParams...); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
// try with upcoming activities
|
|
err = sqlx.GetContext(ctx, q, &result, getUpcomingStmt, execID)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ctxerr.Wrap(ctx, notFound("HostScriptResult").WithName(execID))
|
|
}
|
|
}
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "get host script result")
|
|
}
|
|
}
|
|
return &result, nil
|
|
}
|
|
|
|
func (ds *Datastore) NewScript(ctx context.Context, script *fleet.Script) (*fleet.Script, error) {
|
|
var res sql.Result
|
|
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var err error
|
|
|
|
// first insert script contents
|
|
scRes, err := insertScriptContents(ctx, tx, script.ScriptContents)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
id, _ := scRes.LastInsertId()
|
|
|
|
// then create the script entity
|
|
res, err = insertScript(ctx, tx, script, uint(id)) //nolint:gosec // dismiss G115
|
|
return err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
id, _ := res.LastInsertId()
|
|
return ds.getScriptDB(ctx, ds.writer(ctx), uint(id)) //nolint:gosec // dismiss G115
|
|
}
|
|
|
|
func (ds *Datastore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) {
|
|
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
// Get the current script_content_id
|
|
var oldContentID int64
|
|
getCurrentStmt := `SELECT script_content_id FROM scripts WHERE id = ?`
|
|
err := sqlx.GetContext(ctx, tx, &oldContentID, getCurrentStmt, scriptID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "getting current script content id")
|
|
}
|
|
|
|
// Insert or get existing content (insertScriptContents handles deduplication)
|
|
scRes, err := insertScriptContents(ctx, tx, scriptContents)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "inserting/getting script contents")
|
|
}
|
|
newContentID, _ := scRes.LastInsertId()
|
|
|
|
// Update the script to point to the new content
|
|
if newContentID != oldContentID {
|
|
updateStmt := `
|
|
UPDATE scripts
|
|
SET script_content_id = ?
|
|
WHERE id = ?
|
|
`
|
|
_, err = tx.ExecContext(ctx, updateStmt, newContentID, scriptID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "updating script content reference")
|
|
}
|
|
|
|
// Try to clean up the old content if no longer used
|
|
// Don't fail the transaction if cleanup fails; just log it
|
|
if err := ds.cleanupScriptContent(ctx, tx, uint(oldContentID)); err != nil { //nolint:gosec
|
|
level.Error(ds.logger).Log("msg", "failed to cleanup orphaned script content",
|
|
"script_id", scriptID, "old_content_id", oldContentID, "err", err)
|
|
ctxerr.Handle(ctx, err)
|
|
}
|
|
} else {
|
|
// Just update the timestamp
|
|
_, err = tx.ExecContext(ctx, "UPDATE scripts SET updated_at = NOW() WHERE id = ?", scriptID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "updating script updated_at time")
|
|
}
|
|
}
|
|
|
|
// Cancel pending executions
|
|
if err := ds.cancelUpcomingScriptActivities(ctx, tx, scriptID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "canceling upcoming script executions")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "updating script contents")
|
|
}
|
|
return ds.Script(ctx, scriptID)
|
|
}
|
|
|
|
func (ds *Datastore) cancelUpcomingScriptActivities(ctx context.Context, db sqlx.ExtContext, scriptID uint) error {
|
|
const stmt = `
|
|
SELECT
|
|
ua.execution_id,
|
|
ua.host_id
|
|
FROM
|
|
script_upcoming_activities sua
|
|
INNER JOIN
|
|
upcoming_activities ua ON ua.id = sua.upcoming_activity_id
|
|
WHERE
|
|
sua.script_id = ?
|
|
`
|
|
|
|
var upcomingExecutions []struct {
|
|
ExecutionID string `db:"execution_id"`
|
|
HostID uint `db:"host_id"`
|
|
}
|
|
|
|
if err := sqlx.SelectContext(ctx, db, &upcomingExecutions, stmt, scriptID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "selecting upcoming script executions")
|
|
}
|
|
|
|
for _, upcomingExecution := range upcomingExecutions {
|
|
if _, err := ds.cancelHostUpcomingActivity(ctx, db, upcomingExecution.HostID, upcomingExecution.ExecutionID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "canceling upcoming activity")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func insertScript(ctx context.Context, tx sqlx.ExtContext, script *fleet.Script, scriptContentsID uint) (sql.Result, error) {
|
|
const insertStmt = `
|
|
INSERT INTO
|
|
scripts (
|
|
team_id, global_or_team_id, name, script_content_id
|
|
)
|
|
VALUES
|
|
(?, ?, ?, ?)
|
|
`
|
|
var globalOrTeamID uint
|
|
if script.TeamID != nil {
|
|
globalOrTeamID = *script.TeamID
|
|
}
|
|
res, err := tx.ExecContext(ctx, insertStmt,
|
|
script.TeamID, globalOrTeamID, script.Name, scriptContentsID)
|
|
if err != nil {
|
|
if IsDuplicate(err) {
|
|
// name already exists for this team/global
|
|
err = alreadyExists("Script", script.Name)
|
|
} else if isChildForeignKeyError(err) {
|
|
// team does not exist
|
|
err = foreignKey("scripts", fmt.Sprintf("team_id=%v", script.TeamID))
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "insert script")
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func insertScriptContents(ctx context.Context, tx sqlx.ExtContext, contents string) (sql.Result, error) {
|
|
const insertStmt = `
|
|
INSERT INTO
|
|
script_contents (
|
|
md5_checksum, contents
|
|
)
|
|
VALUES (UNHEX(?),?)
|
|
ON DUPLICATE KEY UPDATE
|
|
id=LAST_INSERT_ID(id)
|
|
`
|
|
|
|
md5Checksum := md5ChecksumScriptContent(contents)
|
|
res, err := tx.ExecContext(ctx, insertStmt, md5Checksum, contents)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "insert script contents")
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// md5ChecksumScriptContent returns the MD5 digest of the script contents s as
// a 32-character uppercase hexadecimal string.
func md5ChecksumScriptContent(s string) string {
	digest := md5.Sum([]byte(s)) //nolint:gosec
	return strings.ToUpper(hex.EncodeToString(digest[:]))
}
|
|
|
|
// md5ChecksumBytes returns the MD5 digest of b as a 32-character uppercase
// hexadecimal string.
func md5ChecksumBytes(b []byte) string {
	sum := md5.Sum(b) //nolint:gosec
	// %X on a byte slice prints two uppercase hex digits per byte, no separators.
	return fmt.Sprintf("%X", sum[:])
}
|
|
|
|
func (ds *Datastore) cleanupScriptContent(ctx context.Context, tx sqlx.ExtContext, contentID uint) error {
|
|
// Check if this content is still being used anywhere
|
|
var usageCount int
|
|
stmt := `
|
|
SELECT COUNT(*) FROM (
|
|
SELECT 1 FROM scripts WHERE script_content_id = ?
|
|
UNION ALL
|
|
SELECT 1 FROM setup_experience_scripts WHERE script_content_id = ?
|
|
UNION ALL
|
|
SELECT 1 FROM software_installers WHERE
|
|
install_script_content_id = ?
|
|
OR uninstall_script_content_id = ?
|
|
OR post_install_script_content_id = ?
|
|
UNION ALL
|
|
SELECT 1 FROM script_upcoming_activities WHERE script_content_id = ?
|
|
UNION ALL
|
|
SELECT 1 FROM host_script_results WHERE script_content_id = ?
|
|
) t
|
|
`
|
|
err := sqlx.GetContext(ctx, tx, &usageCount, stmt,
|
|
contentID, contentID, contentID, contentID, contentID, contentID, contentID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "checking script content usage for cleanup")
|
|
}
|
|
|
|
if usageCount == 0 {
|
|
// Not being used, safe to delete
|
|
deleteStmt := `DELETE FROM script_contents WHERE id = ?`
|
|
_, err = tx.ExecContext(ctx, deleteStmt, contentID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "deleting unused script content")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ds *Datastore) Script(ctx context.Context, id uint) (*fleet.Script, error) {
|
|
return ds.getScriptDB(ctx, ds.reader(ctx), id)
|
|
}
|
|
|
|
func (ds *Datastore) getScriptDB(ctx context.Context, q sqlx.QueryerContext, id uint) (*fleet.Script, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
id,
|
|
team_id,
|
|
name,
|
|
created_at,
|
|
updated_at,
|
|
script_content_id
|
|
FROM
|
|
scripts
|
|
WHERE
|
|
id = ?
|
|
`
|
|
var script fleet.Script
|
|
if err := sqlx.GetContext(ctx, q, &script, getStmt, id); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, notFound("Script").WithID(id)
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "get script")
|
|
}
|
|
return &script, nil
|
|
}
|
|
|
|
func (ds *Datastore) GetScriptContents(ctx context.Context, id uint) ([]byte, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
sc.contents
|
|
FROM
|
|
script_contents sc
|
|
JOIN scripts s ON s.script_content_id = sc.id
|
|
WHERE
|
|
s.id = ?
|
|
`
|
|
var contents []byte
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &contents, getStmt, id); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, notFound("Script").WithID(id)
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "get script contents")
|
|
}
|
|
return contents, nil
|
|
}
|
|
|
|
func (ds *Datastore) GetAnyScriptContents(ctx context.Context, id uint) ([]byte, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
sc.contents
|
|
FROM
|
|
script_contents sc
|
|
WHERE
|
|
sc.id = ?
|
|
`
|
|
var contents []byte
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &contents, getStmt, id); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, notFound("Script").WithID(id)
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "get any script contents")
|
|
}
|
|
return contents, nil
|
|
}
|
|
|
|
// errDeleteScriptWithAssociatedPolicy is returned by DeleteScript when the
// script cannot be deleted because a policy still references it.
var errDeleteScriptWithAssociatedPolicy = &fleet.ConflictError{
	Message: "Couldn't delete. Policy automation uses this script. Please remove this script from associated policy automations and try again.",
}
|
|
|
|
func (ds *Datastore) DeleteScript(ctx context.Context, id uint) error {
|
|
var activateAffectedHosts []uint
|
|
|
|
err := ds.withTx(ctx, func(tx sqlx.ExtContext) error {
|
|
_, err := tx.ExecContext(ctx, `DELETE FROM host_script_results WHERE script_id = ?
|
|
AND exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)`,
|
|
id, int(constants.MaxServerWaitTime.Seconds()),
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrapf(ctx, err, "cancel pending script executions")
|
|
}
|
|
|
|
// load hosts that will have their upcoming_activities deleted, if that
|
|
// activity is "activated", as that means we will have to call
|
|
// activateNextUpcomingActivity for those hosts.
|
|
loadAffectedHostsStmt := `
|
|
SELECT
|
|
DISTINCT host_id
|
|
FROM
|
|
upcoming_activities ua
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
WHERE sua.script_id = ? AND
|
|
ua.activity_type = 'script' AND
|
|
ua.activated_at IS NOT NULL AND
|
|
(ua.payload->'$.sync_request' = 0 OR
|
|
ua.created_at >= NOW() - INTERVAL ? SECOND)`
|
|
var affectedHosts []uint
|
|
if err := sqlx.SelectContext(ctx, tx, &affectedHosts, loadAffectedHostsStmt,
|
|
id, int(constants.MaxServerWaitTime.Seconds())); err != nil {
|
|
return ctxerr.Wrapf(ctx, err, "load affected hosts")
|
|
}
|
|
activateAffectedHosts = affectedHosts
|
|
|
|
_, err = tx.ExecContext(ctx, `DELETE FROM upcoming_activities
|
|
USING upcoming_activities
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON upcoming_activities.id = sua.upcoming_activity_id
|
|
WHERE sua.script_id = ? AND
|
|
upcoming_activities.activity_type = 'script' AND
|
|
(upcoming_activities.payload->'$.sync_request' = 0 OR
|
|
upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
|
|
`,
|
|
id, int(constants.MaxServerWaitTime.Seconds()),
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrapf(ctx, err, "cancel upcoming pending script executions")
|
|
}
|
|
|
|
_, err = tx.ExecContext(ctx, `DELETE FROM scripts WHERE id = ?`, id)
|
|
if err != nil {
|
|
if isMySQLForeignKey(err) {
|
|
// Check if the script is referenced by a policy automation.
|
|
var count int
|
|
if err := sqlx.GetContext(ctx, tx, &count, `SELECT COUNT(*) FROM policies WHERE script_id = ?`, id); err != nil {
|
|
return ctxerr.Wrapf(ctx, err, "getting reference from policies")
|
|
}
|
|
if count > 0 {
|
|
return ctxerr.Wrap(ctx, errDeleteScriptWithAssociatedPolicy, "delete script")
|
|
}
|
|
}
|
|
return ctxerr.Wrap(ctx, err, "delete script")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// we call this outside of the transaction to avoid a
|
|
// long-running/deadlock-prone transaction, as many hosts could be affected.
|
|
return ds.activateNextUpcomingActivityForBatchOfHosts(ctx, activateAffectedHosts)
|
|
}
|
|
|
|
// deletePendingHostScriptExecutionsForPolicy should be called when a policy is deleted to remove any pending script executions
|
|
func (ds *Datastore) deletePendingHostScriptExecutionsForPolicy(ctx context.Context, teamID *uint, policyID uint) error {
|
|
var globalOrTeamID uint
|
|
if teamID != nil {
|
|
globalOrTeamID = *teamID
|
|
}
|
|
|
|
deletePendingFunc := func(stmt string, args ...any) error {
|
|
_, err := ds.writer(ctx).ExecContext(ctx, stmt, args...)
|
|
return ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
|
|
}
|
|
|
|
deleteHSRStmt := `
|
|
DELETE FROM
|
|
host_script_results
|
|
WHERE
|
|
policy_id = ? AND
|
|
script_id IN (
|
|
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
|
|
) AND
|
|
exit_code IS NULL
|
|
`
|
|
|
|
if err := deletePendingFunc(deleteHSRStmt, policyID, globalOrTeamID); err != nil {
|
|
return err
|
|
}
|
|
|
|
loadAffectedHostsStmt := `
|
|
SELECT
|
|
DISTINCT host_id
|
|
FROM
|
|
upcoming_activities ua
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
WHERE
|
|
ua.activity_type = 'script' AND
|
|
ua.activated_at IS NOT NULL AND
|
|
sua.policy_id = ? AND
|
|
sua.script_id IN (
|
|
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
|
|
)`
|
|
var affectedHosts []uint
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &affectedHosts,
|
|
loadAffectedHostsStmt, policyID, globalOrTeamID); err != nil {
|
|
return err
|
|
}
|
|
|
|
deleteUAStmt := `
|
|
DELETE FROM
|
|
upcoming_activities
|
|
USING
|
|
upcoming_activities
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON upcoming_activities.id = sua.upcoming_activity_id
|
|
WHERE
|
|
upcoming_activities.activity_type = 'script' AND
|
|
sua.policy_id = ? AND
|
|
sua.script_id IN (
|
|
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
|
|
)
|
|
`
|
|
if err := deletePendingFunc(deleteUAStmt, policyID, globalOrTeamID); err != nil {
|
|
return err
|
|
}
|
|
|
|
return ds.activateNextUpcomingActivityForBatchOfHosts(ctx, affectedHosts)
|
|
}
|
|
|
|
func (ds *Datastore) ListScripts(ctx context.Context, teamID *uint, opt fleet.ListOptions) ([]*fleet.Script, *fleet.PaginationMetadata, error) {
|
|
var scripts []*fleet.Script
|
|
|
|
const selectStmt = `
|
|
SELECT
|
|
s.id,
|
|
s.team_id,
|
|
s.name,
|
|
s.created_at,
|
|
s.updated_at
|
|
FROM
|
|
scripts s
|
|
WHERE
|
|
s.global_or_team_id = ?
|
|
`
|
|
var globalOrTeamID uint
|
|
if teamID != nil {
|
|
globalOrTeamID = *teamID
|
|
}
|
|
|
|
args := []any{globalOrTeamID}
|
|
stmt, args := appendListOptionsWithCursorToSQL(selectStmt, args, &opt)
|
|
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &scripts, stmt, args...); err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "select scripts")
|
|
}
|
|
|
|
var metaData *fleet.PaginationMetadata
|
|
if opt.IncludeMetadata {
|
|
metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0}
|
|
if len(scripts) > int(opt.PerPage) { //nolint:gosec // dismiss G115
|
|
metaData.HasNextResults = true
|
|
scripts = scripts[:len(scripts)-1]
|
|
}
|
|
}
|
|
return scripts, metaData, nil
|
|
}
|
|
|
|
func (ds *Datastore) GetScriptIDByName(ctx context.Context, name string, teamID *uint) (uint, error) {
|
|
const selectStmt = `
|
|
SELECT
|
|
id
|
|
FROM
|
|
scripts
|
|
WHERE
|
|
global_or_team_id = ?
|
|
AND name = ?
|
|
`
|
|
var globalOrTeamID uint
|
|
if teamID != nil {
|
|
globalOrTeamID = *teamID
|
|
}
|
|
|
|
var id uint
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &id, selectStmt, globalOrTeamID, name); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return 0, notFound("Script").WithName(name)
|
|
}
|
|
return 0, ctxerr.Wrap(ctx, err, "get script by name")
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
// GetHostScriptDetails returns one entry per saved script of the given team
// (or of "no team" when teamID is nil) whose file extension matches the host
// platform. Each entry carries the host's most recent execution of that
// script, if any.
//
// A pending execution from upcoming_activities takes precedence over a
// completed one from host_script_results. Canceled host_script_results rows
// are ignored. Results are paginated according to opt, and pagination
// metadata is only returned when opt.IncludeMetadata is set.
func (ds *Datastore) GetHostScriptDetails(ctx context.Context, hostID uint, teamID *uint, opt fleet.ListOptions, hostPlatform string) ([]*fleet.HostScriptDetail, *fleet.PaginationMetadata, error) {
	// "no team" scripts are stored with global_or_team_id = 0
	var globalOrTeamID uint
	if teamID != nil {
		globalOrTeamID = *teamID
	}

	// Only list scripts that can run on the host's platform. An empty pattern
	// means no extension filter is applied.
	var extension string
	switch {
	case hostPlatform == "windows":
		// filter by .ps1 extension
		extension = `%.ps1`
	case fleet.IsUnixLike(hostPlatform):
		// filter by .sh extension
		extension = `%.sh`
	default:
		// no extension filter
	}

	type row struct {
		ScriptID    uint       `db:"script_id"`
		Name        string     `db:"name"`
		HSRID       *uint      `db:"hsr_id"`
		ExecutionID *string    `db:"execution_id"`
		ExecutedAt  *time.Time `db:"executed_at"`
		ExitCode    *int64     `db:"exit_code"`
	}

	// Named "query" rather than "sql" so it does not shadow the imported
	// database/sql package.
	query := `
WITH all_latest_activities AS (
-- Use window function to efficiently find the latest execution per script
-- This is O(n) (a self-join approach would be O(n²))
SELECT * FROM (
SELECT
id,
host_id,
script_id,
execution_id,
created_at,
exit_code,
'completed' as source,
ROW_NUMBER() OVER (
PARTITION BY script_id
ORDER BY created_at DESC, id DESC
) AS row_num
FROM
host_script_results
WHERE
host_id = ? AND
canceled = 0
) completed_ranked
WHERE row_num = 1

UNION ALL

-- latest from upcoming_activities
SELECT * FROM (
SELECT
NULL as id,
ua.host_id,
sua.script_id,
ua.execution_id,
ua.created_at,
NULL as exit_code,
'upcoming' as source,
ROW_NUMBER() OVER (
PARTITION BY sua.script_id
ORDER BY ua.created_at DESC, ua.id DESC
) AS row_num
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.host_id = ? AND
ua.activity_type = 'script'
) upcoming_ranked
WHERE row_num = 1
)
SELECT
s.id AS script_id,
s.name,
latest.id AS hsr_id,
latest.created_at AS executed_at,
latest.execution_id,
latest.exit_code
FROM
scripts s
LEFT JOIN (
-- Pick the most recent between completed and upcoming for each script
SELECT * FROM (
SELECT
*,
ROW_NUMBER() OVER (
PARTITION BY script_id
ORDER BY
CASE WHEN source = 'upcoming' THEN 1 ELSE 2 END, -- Prefer upcoming over completed
created_at DESC,
id DESC
) AS final_rn
FROM all_latest_activities
) final_ranked
WHERE final_rn = 1
) latest
ON s.id = latest.script_id
WHERE
(latest.host_id IS NULL OR latest.host_id = ?)
AND s.global_or_team_id = ?
`

	// The placeholders, in order: host filter on host_script_results, host
	// filter on upcoming_activities, host filter on the joined result, team.
	args := []any{hostID, hostID, hostID, globalOrTeamID}
	if len(extension) > 0 {
		args = append(args, extension)
		query += `
AND s.name LIKE ?
`
	}
	stmt, args := appendListOptionsWithCursorToSQL(query, args, &opt)

	var rows []*row
	if err := sqlx.SelectContext(ctx, ds.reader(ctx), &rows, stmt, args...); err != nil {
		return nil, nil, ctxerr.Wrap(ctx, err, "get host script details")
	}

	var metaData *fleet.PaginationMetadata
	if opt.IncludeMetadata {
		metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0}
		// an extra row beyond PerPage indicates there is a next page
		if len(rows) > int(opt.PerPage) { //nolint:gosec // dismiss G115
			metaData.HasNextResults = true
			rows = rows[:len(rows)-1]
		}
	}

	results := make([]*fleet.HostScriptDetail, 0, len(rows))
	for _, r := range rows {
		results = append(results, fleet.NewHostScriptDetail(hostID, r.ScriptID, r.Name, r.ExecutionID, r.ExecutedAt, r.ExitCode, r.HSRID))
	}

	return results, metaData, nil
}
|
|
|
|
// BatchSetScripts replaces the saved scripts of the team identified by tmID
// (or of "no team" when tmID is nil, stored as global_or_team_id = 0) with
// the provided scripts. It returns the id, team_id and name of every script
// of that team after the update.
//
// Within a single transaction (run through withRetryTxx) it:
//   - removes the team's scripts whose names are not in the incoming list,
//     after unsetting them from policies and clearing their pending
//     executions (in host_script_results and in upcoming_activities);
//   - inserts new scripts and updates the contents of existing ones, clearing
//     the pending executions that reference a now-obsolete script content.
//
// Pending executions are cleared only if they are not sync requests or were
// created within the last constants.MaxServerWaitTime. Hosts that had an
// already-activated upcoming script activity removed get their next upcoming
// activity activated after the transaction succeeds.
func (ds *Datastore) BatchSetScripts(ctx context.Context, tmID *uint, scripts []*fleet.Script) ([]fleet.ScriptResponse, error) {
	const loadExistingScripts = `
SELECT
name
FROM
scripts
WHERE
global_or_team_id = ? AND
name IN (?)
`
	const deleteAllScriptsInTeam = `
DELETE FROM
scripts
WHERE
global_or_team_id = ?
`
	const unsetAllScriptsFromPolicies = `UPDATE policies SET script_id = NULL WHERE team_id = ?`

	// pending (exit_code IS NULL) results of any script of the team
	const clearAllPendingExecutionsHSR = `DELETE FROM host_script_results WHERE
exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)`

	// hosts with an already-activated upcoming script activity for any script
	// of the team; those must have their next activity activated afterwards.
	const loadAffectedHostsAllPendingExecutionsUA = `
SELECT
DISTINCT host_id
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.activity_type = 'script'
AND ua.activated_at IS NOT NULL
AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)`

	const clearAllPendingExecutionsUA = `DELETE FROM upcoming_activities
USING
upcoming_activities
INNER JOIN script_upcoming_activities sua
ON upcoming_activities.id = sua.upcoming_activity_id
WHERE
upcoming_activities.activity_type = 'script'
AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)`

	const unsetScriptsNotInListFromPolicies = `
UPDATE policies SET script_id = NULL
WHERE script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))
`

	const deleteScriptsNotInList = `
DELETE FROM
scripts
WHERE
global_or_team_id = ? AND
name NOT IN (?)
`

	const clearPendingExecutionsNotInListHSR = `DELETE FROM host_script_results WHERE
exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))`

	const loadAffectedHostsPendingExecutionsNotInListUA = `
SELECT
DISTINCT host_id
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.activity_type = 'script'
AND ua.activated_at IS NOT NULL
AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))`

	const clearPendingExecutionsNotInListUA = `DELETE FROM upcoming_activities
USING
upcoming_activities
INNER JOIN script_upcoming_activities sua
ON upcoming_activities.id = sua.upcoming_activity_id
WHERE
upcoming_activities.activity_type = 'script'
AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))`

	// upsert keyed on the scripts unique key; id=LAST_INSERT_ID(id) makes
	// LastInsertId return the existing row's id when it is an update.
	const insertNewOrEditedScript = `
INSERT INTO
scripts (
team_id, global_or_team_id, name, script_content_id
)
VALUES
(?, ?, ?, ?)
ON DUPLICATE KEY UPDATE
script_content_id = VALUES(script_content_id), id=LAST_INSERT_ID(id)
`

	// pending results of this script that still reference a different
	// (obsolete) script content
	const clearPendingExecutionsWithObsoleteScriptHSR = `DELETE FROM host_script_results WHERE
exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
AND script_id = ? AND script_content_id != ?`

	const loadAffectedHostsPendingExecutionsWithObsoleteScriptUA = `
SELECT
DISTINCT host_id
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.activity_type = 'script'
AND ua.activated_at IS NOT NULL
AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id = ? AND sua.script_content_id != ?`

	const clearPendingExecutionsWithObsoleteScriptUA = `DELETE FROM upcoming_activities
USING
upcoming_activities
INNER JOIN script_upcoming_activities sua
ON upcoming_activities.id = sua.upcoming_activity_id
WHERE
upcoming_activities.activity_type = 'script'
AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id = ? AND sua.script_content_id != ?`

	const loadInsertedScripts = `SELECT id, team_id, name FROM scripts WHERE global_or_team_id = ?`

	// use a team id of 0 if no-team
	var globalOrTeamID uint
	if tmID != nil {
		globalOrTeamID = *tmID
	}

	// build a list of names for the incoming scripts, will keep the
	// existing ones if there's a match and no change
	incomingNames := make([]string, len(scripts))
	// at the same time, index the incoming scripts keyed by name for ease
	// of processing
	incomingScripts := make(map[string]*fleet.Script, len(scripts))
	for i, p := range scripts {
		incomingNames[i] = p.Name
		incomingScripts[p.Name] = p
	}

	// set inside the transaction and used after it succeeds
	var insertedScripts []fleet.ScriptResponse
	var activateAffectedHosts []uint

	if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		var existingScripts []*fleet.Script

		if len(incomingNames) > 0 {
			// load existing scripts that match the incoming scripts by names
			stmt, args, err := sqlx.In(loadExistingScripts, globalOrTeamID, incomingNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build query to load existing scripts")
			}
			if err := sqlx.SelectContext(ctx, tx, &existingScripts, stmt, args...); err != nil {
				return ctxerr.Wrap(ctx, err, "load existing scripts")
			}
		}

		// figure out if we need to delete any scripts
		keepNames := make([]string, 0, len(incomingNames))
		for _, p := range existingScripts {
			if newS := incomingScripts[p.Name]; newS != nil {
				keepNames = append(keepNames, p.Name)
			}
		}

		var (
			scriptsStmt     string
			scriptsArgs     []any
			policiesStmt    string
			policiesArgs    []any
			executionsStmt  string
			executionsArgs  []any
			extraExecStmt   string
			extraExecArgs   []any
			err             error
			affectedHostIDs []uint
		)
		if len(keepNames) > 0 {
			// delete the obsolete scripts
			scriptsStmt, scriptsArgs, err = sqlx.In(deleteScriptsNotInList, globalOrTeamID, keepNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build statement to delete obsolete scripts")
			}

			policiesStmt, policiesArgs, err = sqlx.In(unsetScriptsNotInListFromPolicies, globalOrTeamID, keepNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build statement to unset obsolete scripts from policies")
			}

			executionsStmt, executionsArgs, err = sqlx.In(clearPendingExecutionsNotInListHSR, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build statement to clear pending script executions from obsolete scripts")
			}

			// the affected hosts must be loaded now, before the upcoming
			// activities that identify them are deleted below.
			loadAffectedStmt, args, err := sqlx.In(loadAffectedHostsPendingExecutionsNotInListUA,
				int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build query to load affected hosts for upcoming script executions")
			}
			if err := sqlx.SelectContext(ctx, tx, &affectedHostIDs, loadAffectedStmt, args...); err != nil {
				return ctxerr.Wrap(ctx, err, "load affected hosts for upcoming script executions")
			}

			extraExecStmt, extraExecArgs, err = sqlx.In(clearPendingExecutionsNotInListUA, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build statement to clear upcoming pending script executions from obsolete scripts")
			}
		} else {
			// none of the existing scripts is kept: remove all of them
			scriptsStmt = deleteAllScriptsInTeam
			scriptsArgs = []any{globalOrTeamID}

			policiesStmt = unsetAllScriptsFromPolicies
			policiesArgs = []any{globalOrTeamID}

			executionsStmt = clearAllPendingExecutionsHSR
			executionsArgs = []any{int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID}

			if err := sqlx.SelectContext(ctx, tx, &affectedHostIDs,
				loadAffectedHostsAllPendingExecutionsUA, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID); err != nil {
				return ctxerr.Wrap(ctx, err, "load affected hosts for upcoming script executions")
			}

			extraExecStmt = clearAllPendingExecutionsUA
			extraExecArgs = []any{int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID}
		}
		// Order matters: references to the scripts (policies, pending
		// executions) are removed before the scripts themselves are deleted.
		// NOTE(review): presumably required by foreign keys — confirm against
		// the schema.
		if _, err := tx.ExecContext(ctx, policiesStmt, policiesArgs...); err != nil {
			return ctxerr.Wrap(ctx, err, "unset obsolete scripts from policies")
		}
		if _, err := tx.ExecContext(ctx, executionsStmt, executionsArgs...); err != nil {
			return ctxerr.Wrap(ctx, err, "clear obsolete script pending executions")
		}
		if _, err := tx.ExecContext(ctx, extraExecStmt, extraExecArgs...); err != nil {
			return ctxerr.Wrap(ctx, err, "clear obsolete upcoming script pending executions")
		}
		if _, err := tx.ExecContext(ctx, scriptsStmt, scriptsArgs...); err != nil {
			return ctxerr.Wrap(ctx, err, "delete obsolete scripts")
		}
		// plain assignment (not append) so each transaction attempt starts
		// from this attempt's affected hosts
		activateAffectedHosts = affectedHostIDs

		// insert the new scripts and the ones that have changed
		for _, s := range incomingScripts {
			scRes, err := insertScriptContents(ctx, tx, s.ScriptContents)
			if err != nil {
				return ctxerr.Wrapf(ctx, err, "inserting script contents for script with name %q", s.Name)
			}
			contentID, _ := scRes.LastInsertId()
			insertRes, err := tx.ExecContext(ctx, insertNewOrEditedScript, tmID, globalOrTeamID, s.Name, uint(contentID)) //nolint:gosec // dismiss G115
			if err != nil {
				return ctxerr.Wrapf(ctx, err, "insert new/edited script with name %q", s.Name)
			}
			scriptID, _ := insertRes.LastInsertId()

			// pending executions of this script that use different contents
			// are now obsolete
			if _, err := tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScriptHSR, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil {
				return ctxerr.Wrapf(ctx, err, "clear obsolete pending script executions with name %q", s.Name)
			}

			// load the affected hosts before deleting their upcoming activities
			var affectedHosts []uint
			if err := sqlx.SelectContext(ctx, tx, &affectedHosts, loadAffectedHostsPendingExecutionsWithObsoleteScriptUA,
				int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil {
				return ctxerr.Wrapf(ctx, err, "load affected hosts for upcoming script executions with name %q", s.Name)
			}
			activateAffectedHosts = append(activateAffectedHosts, affectedHosts...)

			if _, err = tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScriptUA, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil {
				return ctxerr.Wrapf(ctx, err, "clear obsolete upcoming pending script executions with name %q", s.Name)
			}
		}

		if err := sqlx.SelectContext(ctx, tx, &insertedScripts, loadInsertedScripts, globalOrTeamID); err != nil {
			return ctxerr.Wrap(ctx, err, "load inserted scripts")
		}

		return nil
	}); err != nil {
		return nil, err
	}

	// done outside the transaction, once the removals are committed
	if err := ds.activateNextUpcomingActivityForBatchOfHosts(ctx, activateAffectedHosts); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "activate next upcoming activity for batch of hosts")
	}

	return insertedScripts, nil
}
|
|
|
|
// GetHostLockWipeStatus returns the lock/unlock/wipe state of the host as
// recorded in host_mdm_actions. If the host has no such row, a zero-value
// status (with only HostFleetPlatform set) is returned, not an error.
//
// The meaning of the stored references depends on the platform:
//   - for "darwin", "ios" and "ipados", lock and wipe references are Apple MDM
//     commands and the unlock reference is a timestamp;
//   - for "windows" and "linux", lock and unlock references are script
//     executions, and wipe is an MDM command on Windows and a script on Linux.
func (ds *Datastore) GetHostLockWipeStatus(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
	const stmt = `
SELECT
lock_ref,
wipe_ref,
unlock_ref,
unlock_pin,
fleet_platform
FROM
host_mdm_actions
WHERE
host_id = ?
`

	var mdmActions struct {
		LockRef       *string `db:"lock_ref"`
		WipeRef       *string `db:"wipe_ref"`
		UnlockRef     *string `db:"unlock_ref"`
		UnlockPIN     *string `db:"unlock_pin"`
		FleetPlatform string  `db:"fleet_platform"`
	}
	fleetPlatform := host.FleetPlatform()
	status := &fleet.HostLockWipeStatus{
		HostFleetPlatform: fleetPlatform,
	}

	if err := sqlx.GetContext(ctx, ds.reader(ctx), &mdmActions, stmt, host.ID); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// do not return a Not Found error, return the zero-value status, which
			// will report the correct states.
			return status, nil
		}
		return nil, ctxerr.Wrap(ctx, err, "get host lock/wipe status")
	}

	// if we have a fleet platform stored in host_mdm_actions, use it instead of
	// the host.FleetPlatform() because the platform can be overwritten with an
	// unknown OS name when a Wipe gets executed.
	if mdmActions.FleetPlatform != "" {
		fleetPlatform = mdmActions.FleetPlatform
		status.HostFleetPlatform = fleetPlatform
	}

	switch fleetPlatform {
	case "darwin", "ios", "ipados":
		if mdmActions.UnlockPIN != nil {
			status.UnlockPIN = *mdmActions.UnlockPIN
		}
		if mdmActions.UnlockRef != nil {
			// unlock_ref holds a zone-less time.DateTime timestamp; time.Parse
			// interprets it as UTC.
			var err error
			status.UnlockRequestedAt, err = time.Parse(time.DateTime, *mdmActions.UnlockRef)
			if err != nil {
				// if the format is unexpected but there's something in UnlockRef, just
				// replace it with the current timestamp, it should still indicate that
				// an unlock was requested (e.g. in case someone plays with the data
				// directly in the DB and messes up the format).
				status.UnlockRequestedAt = time.Now().UTC()
			}
		}

		if mdmActions.LockRef != nil {
			// the lock reference is an MDM command
			cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.LockRef, host.UUID)
			if err != nil {
				return nil, ctxerr.Wrap(ctx, err, "get lock reference")
			}
			status.LockMDMCommand = cmd
			status.LockMDMCommandResult = cmdRes
		}

		if mdmActions.WipeRef != nil {
			// the wipe reference is an MDM command
			cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.WipeRef, host.UUID)
			if err != nil {
				return nil, ctxerr.Wrap(ctx, err, "get wipe reference")
			}
			status.WipeMDMCommand = cmd
			status.WipeMDMCommandResult = cmdRes
		}

	case "windows", "linux":
		// lock and unlock references are scripts
		if mdmActions.LockRef != nil {
			hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.LockRef, scriptExecutionSearchOpts{IncludeCanceled: true})
			if err != nil {
				return nil, ctxerr.Wrap(ctx, err, "get lock reference script result")
			}
			status.LockScript = hsr
		}

		if mdmActions.UnlockRef != nil {
			hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.UnlockRef, scriptExecutionSearchOpts{IncludeCanceled: true})
			if err != nil {
				return nil, ctxerr.Wrap(ctx, err, "get unlock reference script result")
			}
			status.UnlockScript = hsr
		}

		// wipe is an MDM command on Windows, a script on Linux
		if mdmActions.WipeRef != nil {
			if fleetPlatform == "windows" {
				cmd, cmdRes, err := ds.getHostMDMWindowsCommand(ctx, *mdmActions.WipeRef, host.UUID)
				if err != nil {
					return nil, ctxerr.Wrap(ctx, err, "get wipe reference")
				}
				status.WipeMDMCommand = cmd
				status.WipeMDMCommandResult = cmdRes
			} else {
				hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.WipeRef, scriptExecutionSearchOpts{IncludeCanceled: true})
				if err != nil {
					return nil, ctxerr.Wrap(ctx, err, "get wipe reference script result")
				}
				status.WipeScript = hsr
			}
		}
	}

	return status, nil
}
|
|
|
|
func (ds *Datastore) getHostMDMWindowsCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) {
|
|
cmd, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command")
|
|
}
|
|
|
|
// get the MDM command result, which may be not found (indicating the command
|
|
// is pending). Note that it doesn't return ErrNoRows if not found, it
|
|
// returns success and an empty cmdRes slice.
|
|
cmdResults, err := ds.GetMDMWindowsCommandResults(ctx, cmdUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command result")
|
|
}
|
|
|
|
// each item in the slice returned by GetMDMWindowsCommandResults is
|
|
// potentially a result for a different host, we need to find the one for
|
|
// that specific host.
|
|
var cmdRes *fleet.MDMCommandResult
|
|
for _, r := range cmdResults {
|
|
if r.HostUUID != hostUUID {
|
|
continue
|
|
}
|
|
// all statuses for Windows indicate end of processing of the command
|
|
// (there is no equivalent of "NotNow" or "Idle" as for Apple).
|
|
cmdRes = r
|
|
break
|
|
}
|
|
return cmd, cmdRes, nil
|
|
}
|
|
|
|
func (ds *Datastore) getHostMDMAppleCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) {
|
|
cmd, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command")
|
|
}
|
|
|
|
// get the MDM command result, which may be not found (indicating the command
|
|
// is pending). Note that it doesn't return ErrNoRows if not found, it
|
|
// returns success and an empty cmdRes slice.
|
|
cmdResults, err := ds.GetMDMAppleCommandResults(ctx, cmdUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command result")
|
|
}
|
|
|
|
// each item in the slice returned by GetMDMAppleCommandResults is
|
|
// potentially a result for a different host, we need to find the one for
|
|
// that specific host.
|
|
var cmdRes *fleet.MDMCommandResult
|
|
for _, r := range cmdResults {
|
|
if r.HostUUID != hostUUID {
|
|
continue
|
|
}
|
|
if r.Status == fleet.MDMAppleStatusAcknowledged || r.Status == fleet.MDMAppleStatusError || r.Status == fleet.MDMAppleStatusCommandFormatError {
|
|
cmdRes = r
|
|
break
|
|
}
|
|
}
|
|
return cmd, cmdRes, nil
|
|
}
|
|
|
|
// LockHostViaScript will create the script execution request and update
// host_mdm_actions in a single transaction.
//
// The lock script contents are inserted first, then an execution request is
// created for request.HostID and its execution ID is recorded as the host's
// lock_ref. If the host already has a host_mdm_actions row, only lock_ref is
// updated (hostFleetPlatform is recorded only for a new row).
func (ds *Datastore) LockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
	return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		scRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "lock host via script insert script contents")
		}

		// NOTE(review): the LastInsertId error is ignored, as elsewhere in this file.
		id, _ := scRes.LastInsertId()
		request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115

		res, err := ds.newHostScriptExecutionRequest(ctx, tx, request, true)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "lock host via script create execution")
		}

		// on duplicate we don't clear any other existing state because at this
		// point in time, this is just a request to lock the host that is recorded,
		// it is pending execution. The host's state should be updated to "locked"
		// only when the script execution is successfully completed, and then any
		// unlock or wipe references should be cleared.
		const stmt = `
INSERT INTO host_mdm_actions
(
host_id,
lock_ref,
fleet_platform
)
VALUES (?,?,?)
ON DUPLICATE KEY UPDATE
lock_ref = VALUES(lock_ref)
`

		if _, err := tx.ExecContext(ctx, stmt,
			request.HostID,
			res.ExecutionID,
			hostFleetPlatform,
		); err != nil {
			return ctxerr.Wrap(ctx, err, "lock host via script update mdm actions")
		}

		return nil
	})
}
|
|
|
|
// UnlockHostViaScript will create the script execution request and update
// host_mdm_actions in a single transaction.
//
// The unlock script contents are inserted first, then an execution request is
// created for request.HostID and its execution ID is recorded as the host's
// unlock_ref. Any stored unlock_pin is cleared. For an existing
// host_mdm_actions row the fleet platform is not updated.
func (ds *Datastore) UnlockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
	return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		scRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "unlock host via script insert script contents")
		}

		// NOTE(review): the LastInsertId error is ignored, as elsewhere in this file.
		id, _ := scRes.LastInsertId()
		request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115

		res, err := ds.newHostScriptExecutionRequest(ctx, tx, request, true)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "unlock host via script create execution")
		}

		// on duplicate we don't clear any other existing state because at this
		// point in time, this is just a request to unlock the host that is
		// recorded, it is pending execution. The host's state should be updated to
		// "unlocked" only when the script execution is successfully completed, and
		// then any lock or wipe references should be cleared.
		const stmt = `
INSERT INTO host_mdm_actions
(
host_id,
unlock_ref,
fleet_platform
)
VALUES (?,?,?)
ON DUPLICATE KEY UPDATE
unlock_ref = VALUES(unlock_ref),
unlock_pin = NULL
`

		if _, err := tx.ExecContext(ctx, stmt,
			request.HostID,
			res.ExecutionID,
			hostFleetPlatform,
		); err != nil {
			return ctxerr.Wrap(ctx, err, "unlock host via script update mdm actions")
		}

		return nil
	})
}
|
|
|
|
// WipeHostViaScript creates the script execution request and updates the
// host_mdm_actions table in a single transaction.
//
// The wipe script contents are inserted first, then an execution request is
// created for request.HostID and its execution ID is recorded as the host's
// wipe_ref. For an existing host_mdm_actions row only wipe_ref is updated.
func (ds *Datastore) WipeHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
	return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		scRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "wipe host via script insert script contents")
		}

		// NOTE(review): the LastInsertId error is ignored, as elsewhere in this file.
		id, _ := scRes.LastInsertId()
		request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115

		res, err := ds.newHostScriptExecutionRequest(ctx, tx, request, true)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "wipe host via script create execution")
		}

		// on duplicate we don't clear any other existing state because at this
		// point in time, this is just a request to wipe the host that is recorded,
		// it is pending execution, so if it was locked, it is still locked (so the
		// lock_ref info must still be there).
		const stmt = `
INSERT INTO host_mdm_actions
(
host_id,
wipe_ref,
fleet_platform
)
VALUES (?,?,?)
ON DUPLICATE KEY UPDATE
wipe_ref = VALUES(wipe_ref)
`

		if _, err := tx.ExecContext(ctx, stmt,
			request.HostID,
			res.ExecutionID,
			hostFleetPlatform,
		); err != nil {
			return ctxerr.Wrap(ctx, err, "wipe host via script update mdm actions")
		}

		return nil
	})
}
|
|
|
|
// UnlockHostManually records that an unlock of the host was requested at ts.
// The first request's timestamp is kept; later calls do not overwrite an
// already-set unlock_ref.
func (ds *Datastore) UnlockHostManually(ctx context.Context, hostID uint, hostFleetPlatform string, ts time.Time) error {
	const stmt = `
INSERT INTO host_mdm_actions
(
host_id,
unlock_ref,
fleet_platform
)
VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE
-- do not overwrite if a value is already set
unlock_ref = IF(unlock_ref IS NULL, VALUES(unlock_ref), unlock_ref)
`
	// for macOS, the unlock_ref is just the timestamp at which the user first
	// requested to unlock the host. This then indicates in the host's status
	// that it's pending an unlock (which requires manual intervention by
	// entering a PIN on the device). The /unlock endpoint can be called multiple
	// times, so we record the timestamp of the first time it was requested and
	// from then on, the host is marked as "pending unlock" until the device is
	// actually unlocked with the PIN. The actual unlocking happens when the
	// device sends an Idle MDM request.
	//
	// The timestamp is stored without a time zone and GetHostLockWipeStatus
	// reads it back with time.Parse (i.e. as UTC), so normalize to UTC first.
	unlockRef := ts.UTC().Format(time.DateTime)
	_, err := ds.writer(ctx).ExecContext(ctx, stmt, hostID, unlockRef, hostFleetPlatform)
	return ctxerr.Wrap(ctx, err, "record manual unlock host request")
}
|
|
|
|
// buildHostLockWipeStatusUpdateStmt builds the UPDATE statement, without a
// WHERE clause, that transitions a host's state in host_mdm_actions after the
// action referenced by refCol ("lock_ref", "unlock_ref" or "wipe_ref")
// completes.
//
// If joinPart is non-empty, the table is aliased as "hma", joinPart follows
// the alias, and every SET column is qualified with "hma.". When succeeded is
// false, only refCol itself is cleared so the host stays in its previous
// state. setUnlockRef only matters for a successful "lock_ref": unlock_ref is
// then set to the current UTC time instead of being cleared.
func buildHostLockWipeStatusUpdateStmt(refCol string, succeeded bool, joinPart string, setUnlockRef bool) string {
	var alias string

	stmt := `UPDATE host_mdm_actions `
	if joinPart != "" {
		stmt += `hma ` + joinPart
		alias = "hma."
	}
	stmt += ` SET `

	if succeeded {
		switch refCol {
		case "lock_ref":
			// Note that this must not clear the unlock_pin, because recording the
			// lock request does generate the PIN and store it there to be used by an
			// eventual unlock.
			if !setUnlockRef {
				stmt += fmt.Sprintf("%sunlock_ref = NULL, %[1]swipe_ref = NULL", alias)
			} else {
				// Currently only used for Apple MDM devices.
				// We set the unlock_ref to current time since the device can be unlocked any time after the lock.
				// Apple MDM does not have a concept of unlock pending.
				// The value is stored without a time zone and GetHostLockWipeStatus
				// parses it with time.Parse (i.e. as UTC), so it must be formatted in
				// UTC, not in the server's local time zone.
				stmt += fmt.Sprintf("%sunlock_ref = '%s', %[1]swipe_ref = NULL", alias, time.Now().UTC().Format(time.DateTime))
			}
		case "unlock_ref":
			// a successful unlock clears itself as well as the lock ref, because
			// unlock is the default state so we don't need to keep its unlock_ref
			// around once it's confirmed.
			stmt += fmt.Sprintf("%slock_ref = NULL, %[1]sunlock_ref = NULL, %[1]sunlock_pin = NULL, %[1]swipe_ref = NULL", alias)
		case "wipe_ref":
			stmt += fmt.Sprintf("%slock_ref = NULL, %[1]sunlock_ref = NULL, %[1]sunlock_pin = NULL", alias)
		}
	} else {
		// if the action failed, then we clear the reference to that action itself so
		// the host stays in the previous state (it doesn't transition to the new
		// state). Plain concatenation: refCol must not be used as a format string.
		stmt += alias + refCol + " = NULL"
	}
	return stmt
}
|
|
|
|
// UpdateHostLockWipeStatusFromAppleMDMResult updates the lock/wipe state of
// the host identified by hostUUID after an Apple MDM command completes.
// Only "DeviceLock" and "EraseDevice" request types affect that state; any
// other request type is a no-op.
func (ds *Datastore) UpdateHostLockWipeStatusFromAppleMDMResult(ctx context.Context, hostUUID, cmdUUID, requestType string, succeeded bool) error {
	// a bit of MDM protocol leaking in the mysql layer, but it's either that or
	// the other way around (MDM protocol would translate to database column)
	var refCol string
	setUnlockRef := false

	switch requestType {
	case "DeviceLock":
		refCol, setUnlockRef = "lock_ref", true
	case "EraseDevice":
		refCol = "wipe_ref"
	default:
		return nil
	}

	return updateHostLockWipeStatusFromResultAndHostUUID(ctx, ds.writer(ctx), hostUUID, refCol, cmdUUID, succeeded, setUnlockRef)
}
|
|
|
|
func updateHostLockWipeStatusFromResultAndHostUUID(
|
|
ctx context.Context, tx sqlx.ExtContext, hostUUID, refCol, cmdUUID string, succeeded bool, setUnlockRef bool,
|
|
) error {
|
|
stmt := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, `JOIN hosts h ON hma.host_id = h.id`, setUnlockRef)
|
|
stmt += ` WHERE h.uuid = ? AND hma.` + refCol + ` = ?`
|
|
_, err := tx.ExecContext(ctx, stmt, hostUUID, cmdUUID)
|
|
return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result via host uuid")
|
|
}
|
|
|
|
// updateHostLockWipeStatusFromResult applies the refCol state transition to
// the host_mdm_actions row of the host with the given ID.
func updateHostLockWipeStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, refCol string, succeeded bool) error {
	stmt := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, "", false) + ` WHERE host_id = ?`

	_, err := tx.ExecContext(ctx, stmt, hostID)
	return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result")
}
|
|
|
|
func (ds *Datastore) updateUninstallStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, executionID string, exitCode int) error {
|
|
stmt := `
|
|
UPDATE host_software_installs SET uninstall_script_exit_code = ? WHERE execution_id = ? AND host_id = ?
|
|
`
|
|
if _, err := tx.ExecContext(ctx, stmt, exitCode, executionID, hostID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "update uninstall status from result")
|
|
}
|
|
// NOTE: no need to call activateNextUpcomingActivity here as this function
|
|
// is called from SetHostScriptExecutionResult which will call it before
|
|
// completing.
|
|
return nil
|
|
}
|
|
|
|
func (ds *Datastore) CleanupUnusedScriptContents(ctx context.Context) error {
|
|
deleteStmt := `
|
|
DELETE FROM
|
|
script_contents
|
|
WHERE
|
|
NOT EXISTS (
|
|
SELECT 1 FROM host_script_results WHERE script_content_id = script_contents.id)
|
|
AND NOT EXISTS (
|
|
SELECT 1 FROM scripts WHERE script_content_id = script_contents.id)
|
|
AND NOT EXISTS (
|
|
SELECT 1 FROM software_installers si
|
|
WHERE script_contents.id IN (si.install_script_content_id, si.post_install_script_content_id, si.uninstall_script_content_id)
|
|
)
|
|
AND NOT EXISTS (
|
|
SELECT 1 FROM setup_experience_scripts WHERE script_content_id = script_contents.id
|
|
)
|
|
AND NOT EXISTS (
|
|
SELECT 1 FROM script_upcoming_activities WHERE script_content_id = script_contents.id
|
|
)
|
|
`
|
|
_, err := ds.writer(ctx).ExecContext(ctx, deleteStmt)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "cleaning up unused script contents")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (ds *Datastore) getOrGenerateScriptContentsID(ctx context.Context, contents string) (uint, error) {
|
|
csum := md5ChecksumScriptContent(contents)
|
|
scriptContentsID, err := ds.optimisticGetOrInsert(ctx,
|
|
¶meterizedStmt{
|
|
Statement: `SELECT id FROM script_contents WHERE md5_checksum = UNHEX(?)`,
|
|
Args: []interface{}{csum},
|
|
},
|
|
¶meterizedStmt{
|
|
Statement: `INSERT INTO script_contents (md5_checksum, contents) VALUES (UNHEX(?), ?)`,
|
|
Args: []interface{}{csum, contents},
|
|
},
|
|
)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return scriptContentsID, nil
|
|
}
|
|
|
|
// teamIDEq reports whether two optional team IDs refer to the same team.
// A nil pointer means "no team": two nil IDs are equal, while a nil ID never
// equals a non-nil one, even if the non-nil value is 0.
func teamIDEq(teamID1, teamID2 *uint) bool {
	if teamID1 == nil || teamID2 == nil {
		return teamID1 == nil && teamID2 == nil
	}
	return *teamID1 == *teamID2
}
|
|
|
|
func (ds *Datastore) batchExecuteScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint, batchExecID string) error {
|
|
script, err := ds.Script(ctx, scriptID)
|
|
if err != nil {
|
|
return fleet.NewInvalidArgumentError("script_id", err.Error())
|
|
}
|
|
|
|
invalidHostIDPlatform := "batch-invalid-hostid"
|
|
|
|
// We need full host info to check if hosts are able to run scripts, see svc.RunHostScript
|
|
fullHosts := make([]*fleet.Host, 0, len(hostIDs))
|
|
|
|
// The execution results to be stored in the database
|
|
executions := make([]fleet.BatchExecutionHost, 0, len(fullHosts))
|
|
|
|
// Check that all hosts exist before attempting to process them
|
|
for _, hostID := range hostIDs {
|
|
host, err := ds.Host(ctx, hostID)
|
|
if err != nil {
|
|
fullHosts = append(fullHosts, &fleet.Host{
|
|
ID: hostID,
|
|
Platform: invalidHostIDPlatform,
|
|
})
|
|
continue
|
|
}
|
|
|
|
fullHosts = append(fullHosts, host)
|
|
}
|
|
|
|
if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
for _, host := range fullHosts {
|
|
// Host doesn't exist anymore
|
|
if host.Platform == invalidHostIDPlatform {
|
|
executions = append(executions, fleet.BatchExecutionHost{
|
|
HostID: host.ID,
|
|
Error: &fleet.BatchExecuteInvalidHost,
|
|
})
|
|
continue
|
|
}
|
|
|
|
// Non-orbit-enrolled host (iOS, android)
|
|
noNodeKey := host.OrbitNodeKey == nil || *host.OrbitNodeKey == ""
|
|
// Scripts disabled on host
|
|
scriptsDisabled := host.ScriptsEnabled != nil && !*host.ScriptsEnabled
|
|
|
|
if noNodeKey || scriptsDisabled {
|
|
executions = append(executions, fleet.BatchExecutionHost{
|
|
HostID: host.ID,
|
|
Error: &fleet.BatchExecuteIncompatibleFleetd,
|
|
})
|
|
continue
|
|
}
|
|
|
|
if !fleet.ValidateScriptPlatform(script.Name, host.Platform) {
|
|
executions = append(executions, fleet.BatchExecutionHost{
|
|
HostID: host.ID,
|
|
Error: &fleet.BatchExecuteIncompatiblePlatform,
|
|
})
|
|
continue
|
|
}
|
|
|
|
executionID, _, err := ds.insertNewHostScriptExecution(ctx, tx, &fleet.HostScriptRequestPayload{
|
|
HostID: host.ID,
|
|
UserID: userID,
|
|
ScriptID: &script.ID,
|
|
ScriptContentID: script.ScriptContentID,
|
|
}, false)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "queueing script for bulk execution")
|
|
}
|
|
|
|
executions = append(executions, fleet.BatchExecutionHost{
|
|
HostID: host.ID,
|
|
ExecutionID: &executionID,
|
|
})
|
|
}
|
|
|
|
_, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO batch_activities (execution_id, script_id, status, activity_type, num_targeted, started_at) VALUES (?, ?, ?, ?, ?, NOW())
|
|
ON DUPLICATE KEY UPDATE status = VALUES(status), started_at = VALUES(started_at)`,
|
|
batchExecID,
|
|
script.ID,
|
|
fleet.ScheduledBatchExecutionStarted,
|
|
fleet.BatchExecutionActivityScript,
|
|
len(hostIDs),
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "failed to insert new batch execution")
|
|
}
|
|
|
|
args := make([]map[string]any, 0, len(executions))
|
|
for _, execHost := range executions {
|
|
args = append(args, map[string]any{
|
|
"batch_id": batchExecID,
|
|
"host_id": execHost.HostID,
|
|
"host_execution_id": execHost.ExecutionID,
|
|
"error": execHost.Error,
|
|
})
|
|
}
|
|
|
|
insertStmt := `
|
|
INSERT INTO batch_activity_host_results (
|
|
batch_execution_id,
|
|
host_id,
|
|
host_execution_id,
|
|
error
|
|
) VALUES (
|
|
:batch_id,
|
|
:host_id,
|
|
:host_execution_id,
|
|
:error
|
|
) ON DUPLICATE KEY UPDATE host_execution_id = VALUES(host_execution_id), error = VALUES(error)`
|
|
|
|
if _, err := sqlx.NamedExecContext(ctx, tx, insertStmt, args); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "associating script executions with batch job")
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return fmt.Errorf("creating bulk execution order: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ds *Datastore) BatchExecuteScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint) (string, error) {
|
|
batchExecID := uuid.New().String()
|
|
|
|
script, err := ds.Script(ctx, scriptID)
|
|
if err != nil {
|
|
return "", fleet.NewInvalidArgumentError("script_id", err.Error())
|
|
}
|
|
|
|
for _, hostID := range hostIDs {
|
|
host, err := ds.HostLite(ctx, hostID)
|
|
if err != nil {
|
|
return "", fmt.Errorf("unable to load host information for %d: %w", hostID, err)
|
|
}
|
|
|
|
if !teamIDEq(host.TeamID, script.TeamID) {
|
|
return "", ctxerr.Errorf(ctx, "all hosts must be on the same team as the script")
|
|
}
|
|
}
|
|
|
|
if err := ds.batchExecuteScript(ctx, userID, scriptID, hostIDs, batchExecID); err != nil {
|
|
return "", ctxerr.Wrap(ctx, err, "immediate batch execution")
|
|
}
|
|
|
|
return batchExecID, nil
|
|
}
|
|
|
|
func (ds *Datastore) BatchScheduleScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint, notBefore time.Time) (string, error) {
|
|
batchExecID := uuid.New().String()
|
|
|
|
const batchActivitiesStmt = `INSERT INTO batch_activities (execution_id, job_id, script_id, user_id, status, activity_type, num_targeted) VALUES (?, ?, ?, ?, ?, ?, ?)`
|
|
const batchHostsStmt = `INSERT INTO batch_activity_host_results (batch_execution_id, host_id) VALUES (:exec_id, :host_id)`
|
|
|
|
argBytes, err := json.Marshal(fleet.BatchActivityScriptJobArgs{
|
|
ExecutionID: batchExecID,
|
|
})
|
|
if err != nil {
|
|
return "", ctxerr.Wrap(ctx, err, "encooding job args")
|
|
}
|
|
|
|
if err := ds.withTx(ctx, func(tx sqlx.ExtContext) error {
|
|
job, err := ds.NewJob(ctx, &fleet.Job{
|
|
Name: fleet.BatchActivityScriptsJobName,
|
|
Args: (*json.RawMessage)(&argBytes),
|
|
State: fleet.JobStateQueued,
|
|
NotBefore: notBefore.UTC(),
|
|
})
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "creating new job")
|
|
}
|
|
|
|
_, err = tx.ExecContext(
|
|
ctx,
|
|
batchActivitiesStmt,
|
|
batchExecID,
|
|
job.ID,
|
|
scriptID,
|
|
userID,
|
|
fleet.ScheduledBatchExecutionScheduled,
|
|
fleet.BatchExecutionActivityScript,
|
|
len(hostIDs),
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "inserting new batch activity")
|
|
}
|
|
|
|
args := make([]map[string]any, 0, len(hostIDs))
|
|
|
|
for _, hostID := range hostIDs {
|
|
args = append(args, map[string]any{
|
|
"exec_id": batchExecID,
|
|
"host_id": hostID,
|
|
})
|
|
}
|
|
|
|
if _, err := sqlx.NamedExecContext(ctx, tx, batchHostsStmt, args); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "inserting batch host results")
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return "", ctxerr.Wrap(ctx, err, "creating scheduled script execution")
|
|
}
|
|
|
|
return batchExecID, nil
|
|
}
|
|
|
|
// CancelBatchScript cancels the batch script run identified by executionID.
// It does nothing if the batch has already finished.
//
// All changes are made inside ds.withRetryTxx:
//   - if the batch has an associated job, the job is marked as successful to
//     stop the job worker from running it;
//   - if the batch has started, each host execution selected by stmt is
//     canceled via cancelHostUpcomingActivity, and the batch's canceled flag
//     is set. markActivitiesAsCompleted is then run, so that every started
//     batch whose hosts are all accounted for is marked finished;
//   - otherwise (e.g. scheduled but not yet started), the batch is marked
//     finished and canceled directly, and every targeted host is counted as
//     canceled.
func (ds *Datastore) CancelBatchScript(ctx context.Context, executionID string) error {
	// stmt selects this batch's host executions that still need canceling. A
	// row qualifies when all of the following hold:
	//   - the execution is not canceled;
	//   - it has no exit code yet;
	//   - it has no batch error, i.e. a script was actually queued for the host.
	//
	// NOTE(review): when the LEFT JOIN finds no host_script_results row,
	// hsr.canceled is NULL and "hsr.canceled = 0" rejects the row. The query
	// therefore behaves like an inner join. If queued-but-not-yet-activated
	// executions have no host_script_results row, they are not canceled here.
	// Confirm this is intended.
	stmt := `
	SELECT
		bahr.host_execution_id,
		bahr.host_id
	FROM
		batch_activity_host_results bahr
	LEFT JOIN
		host_script_results hsr ON bahr.host_execution_id = hsr.execution_id -- I think?
	WHERE
		bahr.batch_execution_id = ?
	AND
		hsr.canceled = 0
	AND
		hsr.exit_code IS NULL
	AND
		bahr.error IS NULL`

	// stmtSetCanceled finalizes a batch that never started. Every host result
	// row of the batch is counted as canceled.
	stmtSetCanceled := `
	UPDATE
		batch_activities ba
	SET
		finished_at = NOW(),
		status = 'finished',
		canceled = 1,
		num_canceled = (SELECT COUNT(*) FROM batch_activity_host_results WHERE batch_execution_id = ba.execution_id)
	WHERE
		ba.execution_id = ?`

	// stmtCanceled only flags a started batch as canceled. Its status and
	// counts are left to markActivitiesAsCompleted.
	stmtCanceled := `
	UPDATE
		batch_activities
	SET
		canceled = 1
	WHERE
		execution_id = ?`

	activity, err := ds.GetBatchActivity(ctx, executionID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "getting batch activity")
	}

	if activity.Status == fleet.ScheduledBatchExecutionFinished {
		return nil
	}

	if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		// If job worker exists, mark it as complete to stop it from running
		if jobID := activity.JobID; jobID != nil {
			job, err := ds.GetJob(ctx, *jobID)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "failed to find job associated with batch activity")
			}

			job.State = fleet.JobStateSuccess

			if _, err := ds.updateJob(ctx, tx, *jobID, job); err != nil {
				return ctxerr.Wrap(ctx, err, "updating batch activity job")
			}
		}

		if activity.Status == fleet.ScheduledBatchExecutionStarted {
			// If the batch activity has started, we need to cancel anything in progress or queued
			toCancel := []struct {
				HostExecutionID string `db:"host_execution_id"`
				HostID          uint   `db:"host_id"`
			}{}

			if err := sqlx.SelectContext(ctx, tx, &toCancel, stmt, executionID); err != nil {
				return ctxerr.Wrap(ctx, err, "selecting hosts to cancel")
			}

			for _, host := range toCancel {
				if _, err := ds.cancelHostUpcomingActivity(ctx, tx, host.HostID, host.HostExecutionID); err != nil {
					return ctxerr.Wrap(ctx, err, "canceling upcoming activity")
				}
			}

			if _, err := tx.ExecContext(ctx, stmtCanceled, executionID); err != nil {
				return ctxerr.Wrap(ctx, err, "setting canceled column")
			}

			// Runs inside the same transaction, so the per-host cancellations
			// and the canceled flag set above are visible to its aggregation.
			if err := ds.markActivitiesAsCompleted(ctx, tx); err != nil {
				return ctxerr.Wrap(ctx, err, "marking job as complete and summarizing counts")
			}
		} else {
			// The batch activity is scheduled, but not started
			if _, err := tx.ExecContext(ctx, stmtSetCanceled, executionID); err != nil {
				return ctxerr.Wrap(ctx, err, "setting canceled host count")
			}
		}

		return nil
	}); err != nil {
		return ctxerr.Wrap(ctx, err, "cancel batch script db transaction")
	}

	return nil
}
|
|
|
|
func (ds *Datastore) GetBatchActivity(ctx context.Context, executionID string) (*fleet.BatchActivity, error) {
|
|
const stmt = `
|
|
SELECT
|
|
ba.id,
|
|
ba.script_id,
|
|
s.name as script_name,
|
|
ba.execution_id,
|
|
ba.user_id,
|
|
ba.job_id,
|
|
ba.status,
|
|
ba.activity_type,
|
|
ba.num_targeted,
|
|
ba.num_pending,
|
|
ba.num_ran,
|
|
ba.num_errored,
|
|
ba.num_incompatible,
|
|
ba.num_canceled,
|
|
ba.created_at,
|
|
ba.updated_at,
|
|
ba.started_at,
|
|
ba.finished_at,
|
|
ba.canceled
|
|
FROM
|
|
batch_activities ba
|
|
LEFT JOIN
|
|
scripts s ON s.id = ba.script_id
|
|
WHERE
|
|
execution_id = ?`
|
|
|
|
batchActivity := &fleet.BatchActivity{}
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), batchActivity, stmt, executionID); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "selecting batch activity")
|
|
}
|
|
|
|
return batchActivity, nil
|
|
}
|
|
|
|
func (ds *Datastore) GetBatchActivityHostResults(ctx context.Context, executionID string) ([]*fleet.BatchActivityHostResult, error) {
|
|
const stmt = `
|
|
SELECT
|
|
id,
|
|
batch_execution_id,
|
|
host_id,
|
|
host_execution_id,
|
|
error
|
|
FROM
|
|
batch_activity_host_results
|
|
WHERE
|
|
batch_execution_id = ?`
|
|
|
|
results := []*fleet.BatchActivityHostResult{}
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, stmt, executionID); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "selecting batch activity host results")
|
|
}
|
|
|
|
return results, nil
|
|
}
|
|
|
|
func (ds *Datastore) RunScheduledBatchActivity(ctx context.Context, executionID string) error {
|
|
batchActivity, err := ds.GetBatchActivity(ctx, executionID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "getting batch activity")
|
|
}
|
|
|
|
if batchActivity.Status != fleet.ScheduledBatchExecutionScheduled {
|
|
return ctxerr.New(ctx, "batch job has already been started")
|
|
}
|
|
|
|
if batchActivity.Canceled {
|
|
return ctxerr.New(ctx, "batch job was canceled")
|
|
}
|
|
|
|
if batchActivity.ScriptID == nil {
|
|
return ctxerr.New(ctx, "no script ID present in batch activity")
|
|
}
|
|
|
|
script, err := ds.Script(ctx, *batchActivity.ScriptID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "could not get script")
|
|
}
|
|
|
|
results, err := ds.GetBatchActivityHostResults(ctx, executionID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "getting batch activity host results")
|
|
}
|
|
|
|
hostIDs := []uint{}
|
|
for _, result := range results {
|
|
hostIDs = append(hostIDs, result.HostID)
|
|
}
|
|
|
|
if err := ds.batchExecuteScript(ctx, batchActivity.UserID, script.ID, hostIDs, batchActivity.BatchExecutionID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "scheduled batch script execution")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Deprecated; will be removed in favor of ListBatchScriptExecutions when the batch script details page is ready.
|
|
func (ds *Datastore) BatchExecuteSummary(ctx context.Context, executionID string) (*fleet.BatchActivity, error) {
|
|
stmtExecutions := `
|
|
SELECT
|
|
COUNT(*) as num_targeted,
|
|
COUNT(bsehr.error) as num_did_not_run,
|
|
COUNT(CASE WHEN hsr.exit_code = 0 THEN 1 END) as num_succeeded,
|
|
COUNT(CASE WHEN hsr.exit_code <> 0 THEN 1 END) as num_failed,
|
|
COUNT(CASE WHEN hsr.canceled = 1 AND hsr.exit_code IS NULL THEN 1 END) as num_cancelled
|
|
FROM
|
|
batch_activity_host_results bsehr
|
|
LEFT JOIN
|
|
host_script_results hsr
|
|
ON bsehr.host_execution_id = hsr.execution_id
|
|
WHERE
|
|
bsehr.batch_execution_id = ?`
|
|
|
|
stmtScriptDetails := `
|
|
SELECT
|
|
script_id,
|
|
s.name as script_name,
|
|
s.team_id as team_id,
|
|
bse.created_at as created_at
|
|
FROM
|
|
batch_activities bse
|
|
JOIN
|
|
scripts s
|
|
ON bse.script_id = s.id
|
|
WHERE
|
|
bse.execution_id = ?`
|
|
|
|
var summary fleet.BatchActivity
|
|
var temp_summary struct {
|
|
NumTargeted uint `db:"num_targeted"`
|
|
NumDidNotRun uint `db:"num_did_not_run"`
|
|
NumSucceeded uint `db:"num_succeeded"`
|
|
NumFailed uint `db:"num_failed"`
|
|
NumCancelled uint `db:"num_cancelled"`
|
|
}
|
|
// Fill out the execution details
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &temp_summary, stmtExecutions, executionID); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary")
|
|
}
|
|
|
|
summary.NumTargeted = &temp_summary.NumTargeted
|
|
// NumRan is the number of hosts that actually ran the script successfully.
|
|
summary.NumRan = &temp_summary.NumSucceeded
|
|
// NumErrored is the number of hosts that errored out, which includes
|
|
// both failed and did not run.
|
|
summary.NumErrored = ptr.Uint(temp_summary.NumFailed + temp_summary.NumDidNotRun)
|
|
// NumFailed is the number of hosts that were canceled before execution.
|
|
summary.NumCanceled = &temp_summary.NumCancelled
|
|
// NumPending is the number of hosts that are pending execution.
|
|
summary.NumPending = ptr.Uint(temp_summary.NumTargeted - (temp_summary.NumSucceeded + temp_summary.NumFailed + temp_summary.NumDidNotRun + temp_summary.NumCancelled))
|
|
|
|
// Fill out the script details
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &summary, stmtScriptDetails, executionID); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "selecting script information for bulk execution summary")
|
|
}
|
|
|
|
if summary.TeamID == nil {
|
|
summary.TeamID = ptr.Uint(0)
|
|
}
|
|
|
|
return &summary, nil
|
|
}
|
|
|
|
func (ds *Datastore) ListBatchScriptExecutions(ctx context.Context, filter fleet.BatchExecutionStatusFilter) ([]fleet.BatchActivity, error) {
|
|
stmtExecutions := `
|
|
SELECT *
|
|
FROM (
|
|
-- If batch is finished, get the cached host result counts
|
|
SELECT
|
|
COALESCE(ba.num_targeted, 0) AS num_targeted,
|
|
COALESCE(ba.num_incompatible, 0) AS num_incompatible,
|
|
COALESCE(ba.num_ran, 0) AS num_ran,
|
|
COALESCE(ba.num_errored, 0) AS num_errored,
|
|
COALESCE(ba.num_canceled, 0) AS num_canceled,
|
|
COALESCE(ba.num_pending, 0) AS num_pending,
|
|
ba.execution_id,
|
|
ba.script_id,
|
|
ba.status,
|
|
ba.canceled,
|
|
ba.finished_at,
|
|
ba.started_at,
|
|
s.name AS script_name,
|
|
s.global_or_team_id AS team_id,
|
|
ba.created_at AS created_at,
|
|
j.not_before AS not_before,
|
|
ba.id AS id
|
|
FROM batch_activities ba
|
|
JOIN scripts s ON ba.script_id = s.id
|
|
LEFT JOIN jobs j ON j.id = ba.job_id
|
|
WHERE ( %s ) AND ba.status = 'finished'
|
|
|
|
UNION ALL
|
|
|
|
-- If batch is not finished, calculate the host result counts live.
|
|
SELECT
|
|
COUNT(bahr.host_id) AS num_targeted,
|
|
COUNT(bahr.error) AS num_incompatible,
|
|
COUNT(IF(hsr.exit_code = 0, 1, NULL)) AS num_ran,
|
|
COUNT(IF(hsr.exit_code <> 0, 1, NULL)) AS num_errored,
|
|
COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba.canceled = 1), 1, NULL)) AS num_cancelled,
|
|
(
|
|
COUNT(bahr.host_id)
|
|
- COUNT(bahr.error)
|
|
- COUNT(IF(hsr.exit_code = 0, 1, NULL))
|
|
- COUNT(IF(hsr.exit_code <> 0, 1, NULL))
|
|
- COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba.canceled = 1), 1, NULL))
|
|
) AS num_pending,
|
|
ba.execution_id,
|
|
ba.script_id,
|
|
ba.status,
|
|
ba.canceled,
|
|
ba.finished_at,
|
|
ba.started_at,
|
|
s.name AS script_name,
|
|
s.global_or_team_id AS team_id,
|
|
ba.created_at AS created_at,
|
|
j.not_before AS not_before,
|
|
ba.id AS id
|
|
FROM batch_activities ba
|
|
LEFT JOIN batch_activity_host_results bahr
|
|
ON ba.execution_id = bahr.batch_execution_id
|
|
LEFT JOIN host_script_results hsr
|
|
ON bahr.host_execution_id = hsr.execution_id
|
|
JOIN scripts s
|
|
ON ba.script_id = s.id
|
|
LEFT JOIN jobs j
|
|
ON j.id = ba.job_id
|
|
WHERE ( %s ) AND ba.status <> 'finished'
|
|
GROUP BY ba.id
|
|
) AS u
|
|
ORDER BY
|
|
%s
|
|
LIMIT %d OFFSET %d
|
|
`
|
|
limit := 10
|
|
offset := 0
|
|
args := []any{}
|
|
orderBy := []string{"u.created_at DESC", "u.id DESC"}
|
|
whereClauses := make([]string, 0, 2)
|
|
// If an execution ID is provided, use it to filter the results.
|
|
if filter.ExecutionID != nil && *filter.ExecutionID != "" {
|
|
whereClauses = append(whereClauses, "ba.execution_id = ?")
|
|
args = append(args, *filter.ExecutionID)
|
|
} else {
|
|
// Otherwise filter by status and/or team ID.
|
|
if filter.Status != nil && *filter.Status != "" {
|
|
whereClauses = append(whereClauses, "ba.status = ?")
|
|
args = append(args, *filter.Status)
|
|
switch *filter.Status {
|
|
case string(fleet.ScheduledBatchExecutionScheduled):
|
|
orderBy = append([]string{"u.not_before ASC"}, orderBy...)
|
|
case string(fleet.ScheduledBatchExecutionStarted):
|
|
orderBy = append([]string{"u.started_at DESC"}, orderBy...)
|
|
case string(fleet.ScheduledBatchExecutionFinished):
|
|
orderBy = append([]string{"u.finished_at DESC"}, orderBy...)
|
|
default:
|
|
// no additional ordering
|
|
}
|
|
}
|
|
if filter.TeamID != nil {
|
|
whereClauses = append(whereClauses, "s.global_or_team_id = ?")
|
|
args = append(args, *filter.TeamID)
|
|
}
|
|
}
|
|
|
|
// Double up the args to use them in both WHERE clauses.
|
|
args = append(args, args...)
|
|
|
|
// Use pagination parameters if provided.
|
|
if filter.Limit != nil {
|
|
limit = int(*filter.Limit) //nolint:gosec // dismiss G115
|
|
}
|
|
if filter.Offset != nil {
|
|
offset = int(*filter.Offset) //nolint:gosec // dismiss G115
|
|
}
|
|
where := strings.Join(whereClauses, " AND ")
|
|
stmtExecutions = fmt.Sprintf(stmtExecutions, where, where, strings.Join(orderBy, ", "), limit, offset)
|
|
var summary []fleet.BatchActivity
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &summary, stmtExecutions, args...); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary")
|
|
}
|
|
|
|
return summary, nil
|
|
}
|
|
|
|
func (ds *Datastore) CountBatchScriptExecutions(ctx context.Context, filter fleet.BatchExecutionStatusFilter) (int64, error) {
|
|
stmtExecutions := `
|
|
SELECT
|
|
COUNT(*)
|
|
FROM
|
|
batch_activities ba
|
|
JOIN
|
|
scripts s
|
|
ON ba.script_id = s.id
|
|
WHERE
|
|
%s
|
|
`
|
|
args := []any{}
|
|
whereClauses := make([]string, 0, 2)
|
|
if filter.Status != nil && *filter.Status != "" {
|
|
whereClauses = append(whereClauses, "ba.status = ?")
|
|
args = append(args, *filter.Status)
|
|
}
|
|
if filter.TeamID != nil {
|
|
whereClauses = append(whereClauses, "s.global_or_team_id = ?")
|
|
args = append(args, *filter.TeamID)
|
|
}
|
|
where := strings.Join(whereClauses, " AND ")
|
|
stmtExecutions = fmt.Sprintf(stmtExecutions, where)
|
|
|
|
var count int64
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &count, stmtExecutions, args...); err != nil {
|
|
return 0, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary")
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// markActivitiesAsCompleted moves every batch activity in the 'started' status
// to 'finished' once all of its targeted hosts are accounted for. A host is
// accounted for when it is in one of these categories:
//   - incompatible: a batch error was recorded;
//   - ran: exit_code = 0;
//   - errored: any non-zero exit_code, negative codes included;
//   - canceled.
//
// A batch finishes when the sum of these categories reaches the number of
// targeted hosts. Its final counts are then cached on the batch_activities
// row, num_pending is reset to 0, and finished_at is set.
//
// A host counts as canceled in either of two cases:
//   - its script result is canceled and has no exit code;
//   - it has no script result row and no batch error, and the whole batch is
//     canceled.
//
// tx may be a transaction (see CancelBatchScript) or the writer connection
// (see MarkActivitiesAsCompleted).
func (ds *Datastore) markActivitiesAsCompleted(ctx context.Context, tx sqlx.ExtContext) error {
	const stmt = `
	UPDATE batch_activities AS ba
	JOIN (
		SELECT
			ba2.id AS batch_id,
			COUNT(bahr.host_id) AS num_targeted,
			COUNT(bahr.error) AS num_incompatible,
			COUNT(IF(hsr.exit_code = 0, 1, NULL)) AS num_ran,
			COUNT(IF(hsr.exit_code <> 0, 1, NULL)) AS num_errored,
			COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba2.canceled = 1), 1, NULL)) AS num_canceled
		FROM batch_activities AS ba2
		LEFT JOIN batch_activity_host_results AS bahr
			ON ba2.execution_id = bahr.batch_execution_id
		LEFT JOIN host_script_results AS hsr
			ON bahr.host_execution_id = hsr.execution_id
		WHERE ba2.status = 'started'
		GROUP BY ba2.id
		HAVING (num_incompatible + num_ran + num_errored + num_canceled) >= num_targeted
	) AS agg
		ON agg.batch_id = ba.id
	SET
		ba.status = 'finished',
		ba.finished_at = NOW(),
		ba.num_targeted = agg.num_targeted,
		ba.num_incompatible = agg.num_incompatible,
		ba.num_ran = agg.num_ran,
		ba.num_errored = agg.num_errored,
		ba.num_canceled = agg.num_canceled,
		ba.num_pending = 0
	WHERE ba.status = 'started';
	`
	// TODO -- use `RETURNING` to return the IDs of the updated activities?
	_, err := tx.ExecContext(ctx, stmt)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "marking activities as completed")
	}
	return nil
}
|
|
|
|
func (ds *Datastore) MarkActivitiesAsCompleted(ctx context.Context) error {
|
|
return ds.markActivitiesAsCompleted(ctx, ds.writer(ctx))
|
|
}
|