fleet/server/datastore/mysql/scripts.go
Scott Gress 6c659050c0
Fix Orbit-canceled script runs being counted as "pending" (#33300)
<!-- Add the related story/sub-task/bug number, like Resolves #123, or
remove if NA -->
**Related issue:** Resolves #32220

# Details

This PR fixes an issue where hosts whose running scripts were canceled
by Orbit (e.g. due to timing out) were reported as being still "pending"
on the batch script details view. This was due to our only counting runs
as errored if the error code was > 0, and ignoring negative error codes
(which is what Orbit uses for this case).

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

- [X] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)

## Testing

- [X] Added/updated automated tests
Changed a couple of places where we were using `1` for an error code to
`-1`
- [X] Where appropriate, [automated tests simulate multiple hosts and
test for host
isolation](https://github.com/fleetdm/fleet/blob/main/docs/Contributing/reference/patterns-backend.md#unit-testing)
(updates to one host's records do not affect another)

- [X] QA'd all new/changed functionality manually

For unreleased bug fixes in a release candidate, one of:

- [X] Confirmed that the fix is not expected to adversely impact load
test results
- [X] Alerted the release DRI if additional load testing is needed
2025-09-23 12:22:28 -05:00

2537 lines
79 KiB
Go

package mysql
import (
"context"
"crypto/md5" //nolint:gosec
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"unicode/utf8"
constants "github.com/fleetdm/fleet/v4/pkg/scripts"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
"github.com/go-kit/log/level"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
)
// NewHostScriptExecutionRequest enqueues a script execution for a host and
// returns the created upcoming activity as a HostScriptResult.
//
// When request.ScriptContentID is zero on entry, request.ScriptContents is
// stored first (insertScriptContents) and request.ScriptContentID is set to
// the resulting row id. All work runs inside ds.withRetryTxx. On error the
// returned result is nil.
func (ds *Datastore) NewHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
	// Decide once, before the transaction function runs, whether the contents
	// have to be stored. The function mutates request.ScriptContentID, so
	// re-checking that field on a later invocation would skip the insert even
	// though the row created by an earlier, rolled-back invocation is gone.
	// NOTE(review): assumes withRetryTxx may invoke the function more than
	// once; if it never does, this is behaviorally identical.
	needsContents := request.ScriptContentID == 0

	var res *fleet.HostScriptResult
	// The call is kept as its own statement: in `return res, f()` the order
	// in which res is read and f is called is not specified by the language.
	err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		if needsContents {
			// then we are doing a sync execution, so create the contents first
			scRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
			if err != nil {
				return err
			}
			// the error from LastInsertId is deliberately ignored, as
			// everywhere else in this file
			id, _ := scRes.LastInsertId()
			request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115
		}

		var err error
		res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, false)
		return err
	})
	if err != nil {
		return nil, err
	}
	return res, nil
}
// newHostScriptExecutionRequest inserts a new script execution request using
// the provided transaction and then loads the created upcoming activity back
// so it can be returned as a HostScriptResult.
func (ds *Datastore) newHostScriptExecutionRequest(ctx context.Context, tx sqlx.ExtContext, request *fleet.HostScriptRequestPayload, isInternal bool) (*fleet.HostScriptResult, error) {
	const loadCreatedStmt = `
SELECT
ua.id, ua.host_id, ua.execution_id, ua.created_at, sua.script_id, sua.policy_id, ua.user_id,
payload->'$.sync_request' AS sync_request,
sc.contents as script_contents, sua.setup_experience_script_id
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
INNER JOIN script_contents sc
ON sua.script_content_id = sc.id
WHERE
ua.id = ?
`

	// only the id of the upcoming activity is needed here, the execution id
	// is part of the row loaded below
	_, upcomingID, err := ds.insertNewHostScriptExecution(ctx, tx, request, isInternal)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "inserting new script execution request")
	}

	created := new(fleet.HostScriptResult)
	if err := sqlx.GetContext(ctx, tx, created, loadCreatedStmt, upcomingID); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "getting the created host script activity to return")
	}
	return created, nil
}
// insertNewHostScriptExecution creates the upcoming activity for a script
// execution request, links it to its script details, and then asks for the
// next upcoming activity of the host to be activated.
//
// It returns the generated execution id and the id of the inserted
// upcoming_activities row.
func (ds *Datastore) insertNewHostScriptExecution(ctx context.Context, tx sqlx.ExtContext, request *fleet.HostScriptRequestPayload, isInternal bool) (string, int64, error) {
	const insertUpcomingStmt = `
INSERT INTO upcoming_activities
(host_id, priority, user_id, fleet_initiated, activity_type, execution_id, payload)
VALUES
(?, ?, ?, ?, 'script', ?,
JSON_OBJECT(
'sync_request', ?,
'is_internal', ?,
'user', (SELECT JSON_OBJECT('name', name, 'email', email, 'gravatar_url', gravatar_url) FROM users WHERE id = ?)
)
)`

	const insertScriptDetailsStmt = `
INSERT INTO script_upcoming_activities
(upcoming_activity_id, script_id, script_content_id, policy_id, setup_experience_script_id)
VALUES
(?, ?, ?, ?, ?)
`

	executionID := uuid.New().String()

	upcomingArgs := []any{
		request.HostID,
		request.Priority(),
		request.UserID,
		request.PolicyID != nil, // fleet-initiated if request is via a policy failure
		executionID,
		request.SyncRequest,
		isInternal,
		request.UserID, // used by the sub-select that builds the 'user' object
	}
	upcomingRes, err := tx.ExecContext(ctx, insertUpcomingStmt, upcomingArgs...)
	if err != nil {
		return "", 0, ctxerr.Wrap(ctx, err, "new script upcoming activity")
	}
	upcomingID, _ := upcomingRes.LastInsertId()

	detailsArgs := []any{
		upcomingID,
		request.ScriptID,
		request.ScriptContentID,
		request.PolicyID,
		request.SetupExperienceScriptID,
	}
	if _, err := tx.ExecContext(ctx, insertScriptDetailsStmt, detailsArgs...); err != nil {
		return "", 0, ctxerr.Wrap(ctx, err, "new join script upcoming activity")
	}

	_, err = ds.activateNextUpcomingActivity(ctx, tx, request.HostID, "")
	if err != nil {
		return "", 0, ctxerr.Wrap(ctx, err, "activate next activity")
	}
	return executionID, upcomingID, nil
}
// truncateScriptResult keeps only the tail of output: at most the last 10000
// runes. Output that is already within the limit is returned unchanged.
func truncateScriptResult(output string) string {
	const (
		maxOutputRuneLen = 10000
		// no string of maxOutputRuneLen runes can take more bytes than this
		maxOutputByteLen = utf8.UTFMax * maxOutputRuneLen
	)

	// Cheap pre-trim on bytes: anything before the last maxOutputByteLen
	// bytes can never be part of the kept tail, so avoid converting it to
	// runes below.
	if excess := len(output) - maxOutputByteLen; excess > 0 {
		output = output[excess:]
	}

	if utf8.RuneCountInString(output) <= maxOutputRuneLen {
		return output
	}

	tail := []rune(output)
	return string(tail[len(tail)-maxOutputRuneLen:])
}
// SetHostScriptExecutionResult stores the result reported for a script
// execution (output, runtime, exit code and timeout) on the matching
// host_script_results row. The output is truncated with truncateScriptResult
// before being stored.
//
// It returns the updated result and the name of the action related to the
// execution: "uninstall" when the execution matches a host_software_installs
// row, "lock_ref", "unlock_ref" or "wipe_ref" when it matches the
// corresponding column of the host's host_mdm_actions row, or "" otherwise.
// The returned result is nil when a result with a non-NULL exit code already
// exists (the duplicate is ignored) or when the update affected no row.
func (ds *Datastore) SetHostScriptExecutionResult(ctx context.Context, result *fleet.HostScriptResultPayload) (*fleet.HostScriptResult,
string, error,
) {
// detects a result that was already recorded for this execution
const resultExistsStmt = `
SELECT
1
FROM
host_script_results
WHERE
host_id = ? AND
execution_id = ? AND
exit_code IS NOT NULL
`
const updStmt = `
UPDATE host_script_results SET
output = ?,
runtime = ?,
exit_code = ?,
timeout = ?
WHERE
host_id = ? AND
execution_id = ?`
// resolves which action, if any, the execution corresponds to
const hostMDMActionsStmt = `
SELECT 'uninstall' AS action
FROM
host_software_installs
WHERE
execution_id = :execution_id AND host_id = :host_id
UNION -- host_mdm_actions query (and thus row in union) must be last to avoid #25144
SELECT
CASE
WHEN lock_ref = :execution_id THEN 'lock_ref'
WHEN unlock_ref = :execution_id THEN 'unlock_ref'
WHEN wipe_ref = :execution_id THEN 'wipe_ref'
ELSE ''
END AS action
FROM
host_mdm_actions
WHERE
host_id = :host_id
`
output := truncateScriptResult(result.Output)
// hsr and action are assigned inside the transaction function and returned
// once it has completed successfully
var hsr *fleet.HostScriptResult
var action string
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
var resultExists bool
err := sqlx.GetContext(ctx, tx, &resultExists, resultExistsStmt, result.HostID, result.ExecutionID)
// no rows simply means no result was recorded yet, resultExists stays false
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return ctxerr.Wrap(ctx, err, "check if host script result exists")
}
if resultExists {
level.Debug(ds.logger).Log("msg", "duplicate script execution result sent, will be ignored (original result is preserved)",
"host_id", result.HostID,
"execution_id", result.ExecutionID,
)
// still do the activate next activity to ensure progress as there was
// an unexpected flow if we get here.
if _, err := ds.activateNextUpcomingActivity(ctx, tx, result.HostID, result.ExecutionID); err != nil {
return ctxerr.Wrap(ctx, err, "activate next activity")
}
// succeed but leave hsr nil
return nil
}
res, err := tx.ExecContext(ctx, updStmt,
output,
result.Runtime,
// Windows error codes are signed 32-bit integers, but are
// returned as unsigned integers by the windows API. The
// software that receives them is responsible for casting
// it to a 32-bit signed integer.
// See /orbit/pkg/scripts/exec_windows.go
int32(result.ExitCode), //nolint:gosec // dismiss G115
result.Timeout,
result.HostID,
result.ExecutionID,
)
if err != nil {
return ctxerr.Wrap(ctx, err, "update host script result")
}
if n, _ := res.RowsAffected(); n > 0 {
// it did update, so return the updated result
hsr, err = ds.getHostScriptExecutionResultDB(ctx, tx, result.ExecutionID, scriptExecutionSearchOpts{IncludeCanceled: true})
if err != nil {
return ctxerr.Wrap(ctx, err, "load updated host script result")
}
// look up if that script was a lock/unlock/wipe/uninstall script for that host,
// and if so update the host_mdm_actions table accordingly.
namedArgs := map[string]any{
"host_id": result.HostID,
"execution_id": result.ExecutionID,
}
stmt, args, err := sqlx.Named(hostMDMActionsStmt, namedArgs)
if err != nil {
return ctxerr.Wrap(ctx, err, "build named query for host mdm actions")
}
err = sqlx.GetContext(ctx, tx, &action, stmt, args...)
if err != nil && !errors.Is(err, sql.ErrNoRows) { // ignore ErrNoRows, action will be empty
return ctxerr.Wrap(ctx, err, "lookup host script corresponding mdm action")
}
switch action {
case "":
// do nothing
case "uninstall":
err = ds.updateUninstallStatusFromResult(ctx, tx, result.HostID, result.ExecutionID, result.ExitCode)
if err != nil {
return ctxerr.Wrap(ctx, err, "update host uninstall action based on script result")
}
default: // lock/unlock/wipe
// an exit code of 0 is reported as a success
err = updateHostLockWipeStatusFromResult(ctx, tx, result.HostID, action, result.ExitCode == 0)
if err != nil {
return ctxerr.Wrap(ctx, err, "update host mdm action based on script result")
}
}
}
// done whether or not a row was updated
if _, err := ds.activateNextUpcomingActivity(ctx, tx, result.HostID, result.ExecutionID); err != nil {
return ctxerr.Wrap(ctx, err, "activate next activity")
}
return nil
})
if err != nil {
return nil, "", err
}
return hsr, action, nil
}
// ListPendingHostScriptExecutions lists the upcoming script executions of a
// host, regardless of whether they have been activated yet. When
// onlyShowInternal is true, only internal executions are listed.
func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID uint, onlyShowInternal bool) ([]*fleet.HostScriptResult, error) {
	const onlyReadyToExecute = false
	return ds.listUpcomingHostScriptExecutions(ctx, hostID, onlyShowInternal, onlyReadyToExecute)
}
// ListReadyToExecuteScriptsForHost lists the upcoming script executions of a
// host that have been activated. When onlyShowInternal is true, only
// internal executions are listed.
func (ds *Datastore) ListReadyToExecuteScriptsForHost(ctx context.Context, hostID uint, onlyShowInternal bool) ([]*fleet.HostScriptResult, error) {
	const onlyReadyToExecute = true
	return ds.listUpcomingHostScriptExecutions(ctx, hostID, onlyShowInternal, onlyReadyToExecute)
}
// listUpcomingHostScriptExecutions lists the upcoming activities of type
// 'script' or 'software_uninstall' for a host, activated ones first, then by
// descending priority and ascending creation time.
//
// onlyShowInternal keeps only activities flagged as internal in their
// payload (a missing flag counts as internal); onlyReadyToExecute keeps only
// activities that have been activated.
func (ds *Datastore) listUpcomingHostScriptExecutions(ctx context.Context, hostID uint, onlyShowInternal, onlyReadyToExecute bool) ([]*fleet.HostScriptResult, error) {
	// this selects software uninstalls too as they run as scripts
	const listTemplate = `
SELECT
id,
host_id,
execution_id,
script_id,
created_at
FROM (
SELECT
ua.id,
ua.host_id,
ua.execution_id,
sua.script_id,
ua.priority,
ua.created_at,
IF(ua.activated_at IS NULL, 0, 1) AS topmost
FROM
upcoming_activities ua
-- left join because software_uninstall has no script join
LEFT JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.host_id = ? AND
ua.activity_type IN ('script', 'software_uninstall')
%s
ORDER BY topmost DESC, priority DESC, created_at ASC) t`

	var filters string
	if onlyShowInternal {
		// software_uninstalls are implicitly internal
		filters += " AND COALESCE(ua.payload->'$.is_internal', 1) = 1"
	}
	if onlyReadyToExecute {
		filters += " AND ua.activated_at IS NOT NULL"
	}
	query := fmt.Sprintf(listTemplate, filters)

	var upcoming []*fleet.HostScriptResult
	err := sqlx.SelectContext(ctx, ds.reader(ctx), &upcoming, query, hostID)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "list pending host script executions")
	}
	return upcoming, nil
}
// IsExecutionPendingForHost reports whether the host has at least one
// upcoming activity of type 'script' for the given script.
func (ds *Datastore) IsExecutionPendingForHost(ctx context.Context, hostID uint, scriptID uint) (bool, error) {
	const pendingStmt = `
SELECT
1
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.host_id = ? AND
ua.activity_type = 'script' AND
sua.script_id = ?
`

	// one element per matching upcoming activity, only the count matters
	var matches []*uint
	err := sqlx.SelectContext(ctx, ds.reader(ctx), &matches, pendingStmt, hostID, scriptID)
	if err != nil {
		return false, ctxerr.Wrap(ctx, err, "is execution pending for host")
	}

	isPending := len(matches) != 0
	return isPending, nil
}
// scriptExecutionSearchOpts tunes the lookup done by
// getHostScriptExecutionResultDB.
type scriptExecutionSearchOpts struct {
// IncludeCanceled, when false, restricts the lookup in host_script_results
// to rows with canceled = 0.
IncludeCanceled bool
// UninstallHostID, when greater than 0, restricts the lookup in
// host_script_results to executions that have a matching uninstall row in
// host_software_installs and that belong to that host.
UninstallHostID uint
}
// GetHostScriptExecutionResult loads the script execution identified by
// execID using the default search options.
func (ds *Datastore) GetHostScriptExecutionResult(ctx context.Context, execID string) (*fleet.HostScriptResult, error) {
	var defaultOpts scriptExecutionSearchOpts
	return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID, defaultOpts)
}
// GetSelfServiceUninstallScriptExecutionResult loads the script execution
// identified by execID, with the search scoped to uninstalls of the given
// host.
func (ds *Datastore) GetSelfServiceUninstallScriptExecutionResult(ctx context.Context, execID string, hostID uint) (*fleet.HostScriptResult, error) {
	opts := scriptExecutionSearchOpts{UninstallHostID: hostID}
	return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID, opts)
}
// getHostScriptExecutionResultDB loads the script execution identified by
// execID using the provided queryer.
//
// It first looks in host_script_results, applying opts (see
// scriptExecutionSearchOpts). When no row is found there, it falls back to
// the upcoming activities of type 'script'; opts are not applied to that
// fallback query. A not-found error is returned when neither query returns a
// row.
func (ds *Datastore) getHostScriptExecutionResultDB(ctx context.Context, q sqlx.QueryerContext, execID string, opts scriptExecutionSearchOpts) (*fleet.HostScriptResult, error) {
// arguments must be appended in the order their placeholders appear in
// getActiveStmt: the optional uninstall join comes before the WHERE clause
var activeParams []any
canceledCondition := ""
if !opts.IncludeCanceled {
canceledCondition = " AND hsr.canceled = 0"
}
uninstallCondition := ""
if opts.UninstallHostID > 0 {
uninstallCondition = `JOIN host_software_installs hsi ON hsi.execution_id = hsr.execution_id
AND hsi.uninstall = TRUE AND hsr.host_id = ?`
activeParams = append(activeParams, opts.UninstallHostID)
}
activeParams = append(activeParams, execID)
getActiveStmt := fmt.Sprintf(`
SELECT
hsr.id,
hsr.host_id,
hsr.execution_id,
sc.contents as script_contents,
hsr.script_id,
hsr.policy_id,
hsr.output,
hsr.runtime,
hsr.exit_code,
hsr.timeout,
hsr.created_at,
hsr.user_id,
hsr.sync_request,
hsr.host_deleted_at,
hsr.setup_experience_script_id,
hsr.canceled,
bahr.batch_execution_id
FROM
host_script_results hsr
LEFT JOIN
batch_activity_host_results bahr ON hsr.execution_id = bahr.host_execution_id
JOIN
script_contents sc
%s
WHERE
hsr.execution_id = ? AND
hsr.script_content_id = sc.id
%s
`, uninstallCondition, canceledCondition)
// We don't include upcoming uninstall script executions in results (different activity type, and they're blank anyway)
const getUpcomingStmt = `
SELECT
0 as id,
ua.host_id,
ua.execution_id,
sc.contents as script_contents,
sua.script_id,
sua.policy_id,
'' as output,
0 as runtime,
NULL as exit_code,
NULL as timeout,
ua.created_at,
ua.user_id,
COALESCE(ua.payload->'$.sync_request', 0) as sync_request,
NULL as host_deleted_at,
sua.setup_experience_script_id,
0 as canceled
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
INNER JOIN
script_contents sc
ON sua.script_content_id = sc.id
WHERE
ua.execution_id = ? AND
ua.activity_type = 'script'
`
var result fleet.HostScriptResult
if err := sqlx.GetContext(ctx, q, &result, getActiveStmt, activeParams...); err != nil {
if errors.Is(err, sql.ErrNoRows) {
// try with upcoming activities
err = sqlx.GetContext(ctx, q, &result, getUpcomingStmt, execID)
if errors.Is(err, sql.ErrNoRows) {
return nil, ctxerr.Wrap(ctx, notFound("HostScriptResult").WithName(execID))
}
}
// err is either the non-ErrNoRows failure of the first query or the
// outcome of the fallback query (nil when the fallback found a row)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "get host script result")
}
}
return &result, nil
}
// NewScript stores the contents of the script and creates the script entity
// that references them, both in the same transaction. The created script is
// then loaded back and returned.
func (ds *Datastore) NewScript(ctx context.Context, script *fleet.Script) (*fleet.Script, error) {
	var scriptRes sql.Result

	txErr := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		// the contents row must exist before the script row that points to it
		contentsRes, err := insertScriptContents(ctx, tx, script.ScriptContents)
		if err != nil {
			return err
		}
		contentsID, _ := contentsRes.LastInsertId()

		scriptRes, err = insertScript(ctx, tx, script, uint(contentsID)) //nolint:gosec // dismiss G115
		return err
	})
	if txErr != nil {
		return nil, txErr
	}

	scriptID, _ := scriptRes.LastInsertId()
	// load the script that was just written through the writer
	return ds.getScriptDB(ctx, ds.writer(ctx), uint(scriptID)) //nolint:gosec // dismiss G115
}
// UpdateScriptContents makes the script identified by scriptID reference the
// provided contents and cancels its upcoming executions, all in one
// transaction. The updated script is returned.
//
// When the contents are unchanged only the script's updated_at is bumped.
// When they changed, the previous contents row is cleaned up if nothing else
// uses it; a failure of that cleanup is logged and does not fail the update.
func (ds *Datastore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) {
	const (
		currentContentStmt = `SELECT script_content_id FROM scripts WHERE id = ?`
		repointStmt        = `
UPDATE scripts
SET script_content_id = ?
WHERE id = ?
`
		touchStmt = "UPDATE scripts SET updated_at = NOW() WHERE id = ?"
	)

	updateFn := func(tx sqlx.ExtContext) error {
		// remember which contents the script references right now
		var previousID int64
		if err := sqlx.GetContext(ctx, tx, &previousID, currentContentStmt, scriptID); err != nil {
			return ctxerr.Wrap(ctx, err, "getting current script content id")
		}

		// insertScriptContents handles deduplication, so this yields either a
		// new row or the existing one for identical contents
		contentsRes, err := insertScriptContents(ctx, tx, scriptContents)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "inserting/getting script contents")
		}
		currentID, _ := contentsRes.LastInsertId()

		if currentID == previousID {
			// same contents: just update the timestamp
			if _, err := tx.ExecContext(ctx, touchStmt, scriptID); err != nil {
				return ctxerr.Wrap(ctx, err, "updating script updated_at time")
			}
		} else {
			// point the script to the new contents
			if _, err := tx.ExecContext(ctx, repointStmt, currentID, scriptID); err != nil {
				return ctxerr.Wrap(ctx, err, "updating script content reference")
			}

			// best-effort removal of the previous contents if now unused; a
			// failure here is reported but does not fail the transaction
			if cleanupErr := ds.cleanupScriptContent(ctx, tx, uint(previousID)); cleanupErr != nil { //nolint:gosec
				level.Error(ds.logger).Log("msg", "failed to cleanup orphaned script content",
					"script_id", scriptID, "old_content_id", previousID, "err", cleanupErr)
				ctxerr.Handle(ctx, cleanupErr)
			}
		}

		// cancel pending executions
		if err := ds.cancelUpcomingScriptActivities(ctx, tx, scriptID); err != nil {
			return ctxerr.Wrap(ctx, err, "canceling upcoming script executions")
		}
		return nil
	}

	if err := ds.withRetryTxx(ctx, updateFn); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "updating script contents")
	}
	return ds.Script(ctx, scriptID)
}
// cancelUpcomingScriptActivities cancels, one by one, every upcoming activity
// that references the given script. It stops at the first cancellation that
// fails.
func (ds *Datastore) cancelUpcomingScriptActivities(ctx context.Context, db sqlx.ExtContext, scriptID uint) error {
	const listUpcomingStmt = `
SELECT
ua.execution_id,
ua.host_id
FROM
script_upcoming_activities sua
INNER JOIN
upcoming_activities ua ON ua.id = sua.upcoming_activity_id
WHERE
sua.script_id = ?
`

	type upcomingRef struct {
		ExecutionID string `db:"execution_id"`
		HostID      uint   `db:"host_id"`
	}

	var refs []upcomingRef
	err := sqlx.SelectContext(ctx, db, &refs, listUpcomingStmt, scriptID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "selecting upcoming script executions")
	}

	for i := range refs {
		hostID, executionID := refs[i].HostID, refs[i].ExecutionID
		_, err := ds.cancelHostUpcomingActivity(ctx, db, hostID, executionID)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "canceling upcoming activity")
		}
	}
	return nil
}
// insertScript inserts the script entity, referencing the already stored
// contents identified by scriptContentsID, and returns the result of the
// insert.
//
// A duplicate-key failure is reported as an "already exists" error for the
// script name, and a child foreign key failure as a foreign key error on the
// team id. Any other failure is returned wrapped.
func insertScript(ctx context.Context, tx sqlx.ExtContext, script *fleet.Script, scriptContentsID uint) (sql.Result, error) {
	const insertStmt = `
INSERT INTO
scripts (
team_id, global_or_team_id, name, script_content_id
)
VALUES
(?, ?, ?, ?)
`
	// global_or_team_id is 0 when the script does not belong to a team
	var globalOrTeamID uint
	if script.TeamID != nil {
		globalOrTeamID = *script.TeamID
	}
	res, err := tx.ExecContext(ctx, insertStmt,
		script.TeamID, globalOrTeamID, script.Name, scriptContentsID)
	if err != nil {
		if IsDuplicate(err) {
			// name already exists for this team/global
			err = alreadyExists("Script", script.Name)
		} else if isChildForeignKeyError(err) {
			// team does not exist
			//
			// script.TeamID is a pointer: formatting it directly with %v
			// would print its address, not the team id, so dereference it.
			teamDesc := "team_id=<nil>"
			if script.TeamID != nil {
				teamDesc = fmt.Sprintf("team_id=%d", *script.TeamID)
			}
			err = foreignKey("scripts", teamDesc)
		}
		return nil, ctxerr.Wrap(ctx, err, "insert script")
	}
	return res, nil
}
// insertScriptContents stores the script contents keyed by their MD5
// checksum. When a row with the same key already exists, the statement makes
// the existing row's id available as the last insert id of the returned
// result instead of creating a new row.
func insertScriptContents(ctx context.Context, tx sqlx.ExtContext, contents string) (sql.Result, error) {
	const upsertContentsStmt = `
INSERT INTO
script_contents (
md5_checksum, contents
)
VALUES (UNHEX(?),?)
ON DUPLICATE KEY UPDATE
id=LAST_INSERT_ID(id)
`

	checksum := md5ChecksumScriptContent(contents)

	insertRes, err := tx.ExecContext(ctx, upsertContentsStmt, checksum, contents)
	if err == nil {
		return insertRes, nil
	}
	return nil, ctxerr.Wrap(ctx, err, "insert script contents")
}
// md5ChecksumScriptContent returns the MD5 checksum of the script contents
// as computed by md5ChecksumBytes.
func md5ChecksumScriptContent(s string) string {
	contentBytes := []byte(s)
	return md5ChecksumBytes(contentBytes)
}
// md5ChecksumBytes returns the MD5 checksum of b encoded as an upper-case
// hexadecimal string.
func md5ChecksumBytes(b []byte) string {
	digest := md5.Sum(b) //nolint:gosec
	hexDigest := hex.EncodeToString(digest[:])
	return strings.ToUpper(hexDigest)
}
// cleanupScriptContent deletes the script_contents row identified by
// contentID, but only when no script, setup experience script, software
// installer, upcoming script activity or host script result references it.
func (ds *Datastore) cleanupScriptContent(ctx context.Context, tx sqlx.ExtContext, contentID uint) error {
	const (
		usageStmt = `
SELECT COUNT(*) FROM (
SELECT 1 FROM scripts WHERE script_content_id = ?
UNION ALL
SELECT 1 FROM setup_experience_scripts WHERE script_content_id = ?
UNION ALL
SELECT 1 FROM software_installers WHERE
install_script_content_id = ?
OR uninstall_script_content_id = ?
OR post_install_script_content_id = ?
UNION ALL
SELECT 1 FROM script_upcoming_activities WHERE script_content_id = ?
UNION ALL
SELECT 1 FROM host_script_results WHERE script_content_id = ?
) t
`
		deleteStmt = `DELETE FROM script_contents WHERE id = ?`
		// number of placeholders in usageStmt, all bound to contentID
		usagePlaceholders = 7
	)

	usageArgs := make([]any, usagePlaceholders)
	for i := range usageArgs {
		usageArgs[i] = contentID
	}

	// count every remaining reference to this content
	var references int
	if err := sqlx.GetContext(ctx, tx, &references, usageStmt, usageArgs...); err != nil {
		return ctxerr.Wrap(ctx, err, "checking script content usage for cleanup")
	}
	if references != 0 {
		// still in use somewhere, keep it
		return nil
	}

	if _, err := tx.ExecContext(ctx, deleteStmt, contentID); err != nil {
		return ctxerr.Wrap(ctx, err, "deleting unused script content")
	}
	return nil
}
// Script loads the script entity identified by id through the reader.
func (ds *Datastore) Script(ctx context.Context, id uint) (*fleet.Script, error) {
	reader := ds.reader(ctx)
	return ds.getScriptDB(ctx, reader, id)
}
// getScriptDB loads the script entity identified by id using the provided
// queryer. It returns a not-found error when no such script exists; any other
// failure is returned wrapped.
func (ds *Datastore) getScriptDB(ctx context.Context, q sqlx.QueryerContext, id uint) (*fleet.Script, error) {
	const getStmt = `
SELECT
id,
team_id,
name,
created_at,
updated_at,
script_content_id
FROM
scripts
WHERE
id = ?
`
	var script fleet.Script
	if err := sqlx.GetContext(ctx, q, &script, getStmt, id); err != nil {
		// errors.Is also matches a wrapped sql.ErrNoRows, consistent with the
		// other lookups in this file
		if errors.Is(err, sql.ErrNoRows) {
			return nil, notFound("Script").WithID(id)
		}
		return nil, ctxerr.Wrap(ctx, err, "get script")
	}
	return &script, nil
}
// GetScriptContents returns the contents referenced by the script entity
// identified by id. It returns a not-found error when no such script exists;
// any other failure is returned wrapped.
func (ds *Datastore) GetScriptContents(ctx context.Context, id uint) ([]byte, error) {
	const getStmt = `
SELECT
sc.contents
FROM
script_contents sc
JOIN scripts s ON s.script_content_id = sc.id
WHERE
s.id = ?
`
	var contents []byte
	if err := sqlx.GetContext(ctx, ds.reader(ctx), &contents, getStmt, id); err != nil {
		// errors.Is also matches a wrapped sql.ErrNoRows, consistent with the
		// other lookups in this file
		if errors.Is(err, sql.ErrNoRows) {
			return nil, notFound("Script").WithID(id)
		}
		return nil, ctxerr.Wrap(ctx, err, "get script contents")
	}
	return contents, nil
}
// GetAnyScriptContents returns the contents stored in the script_contents
// row identified by id, whatever references that row. Note that id is the id
// of the contents row, not of a script entity.
func (ds *Datastore) GetAnyScriptContents(ctx context.Context, id uint) ([]byte, error) {
	const contentsByIDStmt = `
SELECT
sc.contents
FROM
script_contents sc
WHERE
sc.id = ?
`

	var body []byte
	err := sqlx.GetContext(ctx, ds.reader(ctx), &body, contentsByIDStmt, id)
	switch {
	case err == nil:
		return body, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, notFound("Script").WithID(id)
	default:
		return nil, ctxerr.Wrap(ctx, err, "get any script contents")
	}
}
// errDeleteScriptWithAssociatedPolicy is the conflict error returned by
// DeleteScript when the deletion fails on a foreign key and at least one
// policy still references the script.
var errDeleteScriptWithAssociatedPolicy = &fleet.ConflictError{Message: "Couldn't delete. Policy automation uses this script. Please remove this script from associated policy automations and try again."}
// DeleteScript deletes the script entity identified by id.
//
// In a single transaction it first removes the executions of that script
// that have no exit code yet, from host_script_results and from
// upcoming_activities, limited to non-sync requests or requests created
// within the last constants.MaxServerWaitTime, and then deletes the script
// itself. If that last delete fails on a foreign key and a policy still
// references the script, errDeleteScriptWithAssociatedPolicy is returned.
//
// Once the transaction succeeded, the next upcoming activity is activated for
// the hosts whose activated upcoming activity was deleted.
func (ds *Datastore) DeleteScript(ctx context.Context, id uint) error {
// filled inside the transaction, used after it has completed
var activateAffectedHosts []uint
err := ds.withTx(ctx, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, `DELETE FROM host_script_results WHERE script_id = ?
AND exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)`,
id, int(constants.MaxServerWaitTime.Seconds()),
)
if err != nil {
return ctxerr.Wrapf(ctx, err, "cancel pending script executions")
}
// load hosts that will have their upcoming_activities deleted, if that
// activity is "activated", as that means we will have to call
// activateNextUpcomingActivity for those hosts.
loadAffectedHostsStmt := `
SELECT
DISTINCT host_id
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE sua.script_id = ? AND
ua.activity_type = 'script' AND
ua.activated_at IS NOT NULL AND
(ua.payload->'$.sync_request' = 0 OR
ua.created_at >= NOW() - INTERVAL ? SECOND)`
var affectedHosts []uint
if err := sqlx.SelectContext(ctx, tx, &affectedHosts, loadAffectedHostsStmt,
id, int(constants.MaxServerWaitTime.Seconds())); err != nil {
return ctxerr.Wrapf(ctx, err, "load affected hosts")
}
activateAffectedHosts = affectedHosts
// this must run after the affected hosts were loaded, as it deletes the
// rows that the query above selects from
_, err = tx.ExecContext(ctx, `DELETE FROM upcoming_activities
USING upcoming_activities
INNER JOIN script_upcoming_activities sua
ON upcoming_activities.id = sua.upcoming_activity_id
WHERE sua.script_id = ? AND
upcoming_activities.activity_type = 'script' AND
(upcoming_activities.payload->'$.sync_request' = 0 OR
upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
`,
id, int(constants.MaxServerWaitTime.Seconds()),
)
if err != nil {
return ctxerr.Wrapf(ctx, err, "cancel upcoming pending script executions")
}
_, err = tx.ExecContext(ctx, `DELETE FROM scripts WHERE id = ?`, id)
if err != nil {
if isMySQLForeignKey(err) {
// Check if the script is referenced by a policy automation.
var count int
if err := sqlx.GetContext(ctx, tx, &count, `SELECT COUNT(*) FROM policies WHERE script_id = ?`, id); err != nil {
return ctxerr.Wrapf(ctx, err, "getting reference from policies")
}
if count > 0 {
return ctxerr.Wrap(ctx, errDeleteScriptWithAssociatedPolicy, "delete script")
}
}
// any other failure, including a foreign key error that is not due to a
// policy, is returned as is
return ctxerr.Wrap(ctx, err, "delete script")
}
return nil
})
if err != nil {
return err
}
// we call this outside of the transaction to avoid a
// long-running/deadlock-prone transaction, as many hosts could be affected.
return ds.activateNextUpcomingActivityForBatchOfHosts(ctx, activateAffectedHosts)
}
// deletePendingHostScriptExecutionsForPolicy should be called when a policy is deleted to remove any pending script executions
// that were requested for that policy.
//
// It deletes the matching rows without an exit code from host_script_results
// and the matching rows from upcoming_activities, restricted to scripts of
// the given team (or of no team when teamID is nil), and then activates the
// next upcoming activity for the hosts whose activated upcoming activity was
// deleted. The statements are not run in a transaction.
func (ds *Datastore) deletePendingHostScriptExecutionsForPolicy(ctx context.Context, teamID *uint, policyID uint) error {
	// global_or_team_id is 0 for scripts that do not belong to a team
	var globalOrTeamID uint
	if teamID != nil {
		globalOrTeamID = *teamID
	}

	// runs a delete statement on the writer and wraps its error, if any
	deletePendingFunc := func(stmt string, args ...any) error {
		_, err := ds.writer(ctx).ExecContext(ctx, stmt, args...)
		return ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
	}

	deleteHSRStmt := `
DELETE FROM
host_script_results
WHERE
policy_id = ? AND
script_id IN (
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
) AND
exit_code IS NULL
`

	if err := deletePendingFunc(deleteHSRStmt, policyID, globalOrTeamID); err != nil {
		return err
	}

	loadAffectedHostsStmt := `
SELECT
DISTINCT host_id
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.activity_type = 'script' AND
ua.activated_at IS NOT NULL AND
sua.policy_id = ? AND
sua.script_id IN (
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
)`

	// The affected hosts are loaded from the writer, the same connection the
	// deletes run on, so the list matches what the delete below will remove;
	// it must be loaded before that delete as it selects from the same rows.
	var affectedHosts []uint
	if err := sqlx.SelectContext(ctx, ds.writer(ctx), &affectedHosts,
		loadAffectedHostsStmt, policyID, globalOrTeamID); err != nil {
		return ctxerr.Wrap(ctx, err, "load affected hosts")
	}

	deleteUAStmt := `
DELETE FROM
upcoming_activities
USING
upcoming_activities
INNER JOIN script_upcoming_activities sua
ON upcoming_activities.id = sua.upcoming_activity_id
WHERE
upcoming_activities.activity_type = 'script' AND
sua.policy_id = ? AND
sua.script_id IN (
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
)
`
	if err := deletePendingFunc(deleteUAStmt, policyID, globalOrTeamID); err != nil {
		return err
	}

	return ds.activateNextUpcomingActivityForBatchOfHosts(ctx, affectedHosts)
}
// ListScripts lists the scripts of the given team, or the scripts that belong
// to no team when teamID is nil, applying the list options.
//
// Pagination metadata is returned only when opt.IncludeMetadata is set. In
// that case, when more rows than opt.PerPage were loaded, the last one is
// dropped from the results and reported as HasNextResults.
func (ds *Datastore) ListScripts(ctx context.Context, teamID *uint, opt fleet.ListOptions) ([]*fleet.Script, *fleet.PaginationMetadata, error) {
	const listStmt = `
SELECT
s.id,
s.team_id,
s.name,
s.created_at,
s.updated_at
FROM
scripts s
WHERE
s.global_or_team_id = ?
`

	// scripts that belong to no team are stored with global_or_team_id = 0
	globalOrTeamID := uint(0)
	if teamID != nil {
		globalOrTeamID = *teamID
	}

	query, queryArgs := appendListOptionsWithCursorToSQL(listStmt, []any{globalOrTeamID}, &opt)

	var scripts []*fleet.Script
	err := sqlx.SelectContext(ctx, ds.reader(ctx), &scripts, query, queryArgs...)
	if err != nil {
		return nil, nil, ctxerr.Wrap(ctx, err, "select scripts")
	}

	if !opt.IncludeMetadata {
		return scripts, nil, nil
	}

	meta := &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0}
	if pageSize := int(opt.PerPage); len(scripts) > pageSize { //nolint:gosec // dismiss G115
		// the extra row only signals that there is a next page
		meta.HasNextResults = true
		scripts = scripts[:len(scripts)-1]
	}
	return scripts, meta, nil
}
// GetScriptIDByName returns the id of the script with the given name in the
// given team, or among the scripts that belong to no team when teamID is
// nil. It returns a not-found error when no such script exists; any other
// failure is returned wrapped.
func (ds *Datastore) GetScriptIDByName(ctx context.Context, name string, teamID *uint) (uint, error) {
	const selectStmt = `
SELECT
id
FROM
scripts
WHERE
global_or_team_id = ?
AND name = ?
`
	// scripts that belong to no team are stored with global_or_team_id = 0
	var globalOrTeamID uint
	if teamID != nil {
		globalOrTeamID = *teamID
	}

	var id uint
	if err := sqlx.GetContext(ctx, ds.reader(ctx), &id, selectStmt, globalOrTeamID, name); err != nil {
		// errors.Is also matches a wrapped sql.ErrNoRows, consistent with the
		// other lookups in this file
		if errors.Is(err, sql.ErrNoRows) {
			return 0, notFound("Script").WithName(name)
		}
		return 0, ctxerr.Wrap(ctx, err, "get script by name")
	}
	return id, nil
}
// GetHostScriptDetails lists the scripts of the given team (or the scripts
// that belong to no team when teamID is nil) together with, for each script,
// the latest execution known for the host, if any. An upcoming execution is
// preferred over a recorded, non-canceled one.
//
// Scripts are filtered by name depending on hostPlatform: names ending in
// .ps1 for "windows", names ending in .sh for unix-like platforms, and no
// filter otherwise. Pagination metadata is returned only when
// opt.IncludeMetadata is set.
func (ds *Datastore) GetHostScriptDetails(ctx context.Context, hostID uint, teamID *uint, opt fleet.ListOptions, hostPlatform string) ([]*fleet.HostScriptDetail, *fleet.PaginationMetadata, error) {
	const baseQuery = `
WITH all_latest_activities AS (
-- Use window function to efficiently find the latest execution per script
-- This is O(n) (a self-join approach would be O(n²))
SELECT * FROM (
SELECT
id,
host_id,
script_id,
execution_id,
created_at,
exit_code,
'completed' as source,
ROW_NUMBER() OVER (
PARTITION BY script_id
ORDER BY created_at DESC, id DESC
) AS row_num
FROM
host_script_results
WHERE
host_id = ? AND
canceled = 0
) completed_ranked
WHERE row_num = 1
UNION ALL
-- latest from upcoming_activities
SELECT * FROM (
SELECT
NULL as id,
ua.host_id,
sua.script_id,
ua.execution_id,
ua.created_at,
NULL as exit_code,
'upcoming' as source,
ROW_NUMBER() OVER (
PARTITION BY sua.script_id
ORDER BY ua.created_at DESC, ua.id DESC
) AS row_num
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.host_id = ? AND
ua.activity_type = 'script'
) upcoming_ranked
WHERE row_num = 1
)
SELECT
s.id AS script_id,
s.name,
latest.id AS hsr_id,
latest.created_at AS executed_at,
latest.execution_id,
latest.exit_code
FROM
scripts s
LEFT JOIN (
-- Pick the most recent between completed and upcoming for each script
SELECT * FROM (
SELECT
*,
ROW_NUMBER() OVER (
PARTITION BY script_id
ORDER BY
CASE WHEN source = 'upcoming' THEN 1 ELSE 2 END, -- Prefer upcoming over completed
created_at DESC,
id DESC
) AS final_rn
FROM all_latest_activities
) final_ranked
WHERE final_rn = 1
) latest
ON s.id = latest.script_id
WHERE
(latest.host_id IS NULL OR latest.host_id = ?)
AND s.global_or_team_id = ?
`

	// one row per script, the execution columns are NULL when the script
	// has no execution for the host
	type scriptRow struct {
		ScriptID    uint       `db:"script_id"`
		Name        string     `db:"name"`
		HSRID       *uint      `db:"hsr_id"`
		ExecutionID *string    `db:"execution_id"`
		ExecutedAt  *time.Time `db:"executed_at"`
		ExitCode    *int64     `db:"exit_code"`
	}

	// scripts that belong to no team are stored with global_or_team_id = 0
	globalOrTeamID := uint(0)
	if teamID != nil {
		globalOrTeamID = *teamID
	}

	// LIKE pattern on the script name, left empty when no filter applies
	var namePattern string
	if hostPlatform == "windows" {
		// filter by .ps1 extension
		namePattern = `%.ps1`
	} else if fleet.IsUnixLike(hostPlatform) {
		// filter by .sh extension
		namePattern = `%.sh`
	}

	query := baseQuery
	queryArgs := []any{hostID, hostID, hostID, globalOrTeamID}
	if namePattern != "" {
		queryArgs = append(queryArgs, namePattern)
		query += `
AND s.name LIKE ?
`
	}
	query, queryArgs = appendListOptionsWithCursorToSQL(query, queryArgs, &opt)

	var scriptRows []*scriptRow
	err := sqlx.SelectContext(ctx, ds.reader(ctx), &scriptRows, query, queryArgs...)
	if err != nil {
		return nil, nil, ctxerr.Wrap(ctx, err, "get host script details")
	}

	var meta *fleet.PaginationMetadata
	if opt.IncludeMetadata {
		meta = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0}
		if pageSize := int(opt.PerPage); len(scriptRows) > pageSize { //nolint:gosec // dismiss G115
			// the extra row only signals that there is a next page
			meta.HasNextResults = true
			scriptRows = scriptRows[:len(scriptRows)-1]
		}
	}

	details := make([]*fleet.HostScriptDetail, 0, len(scriptRows))
	for i := range scriptRows {
		sr := scriptRows[i]
		detail := fleet.NewHostScriptDetail(hostID, sr.ScriptID, sr.Name, sr.ExecutionID, sr.ExecutedAt, sr.ExitCode, sr.HSRID)
		details = append(details, detail)
	}
	return details, meta, nil
}
// BatchSetScripts replaces the scripts stored for a team (or for "no team",
// stored with a global_or_team_id of 0, when tmID is nil) with the provided
// list. Everything but the final activation step runs inside a single
// ds.withRetryTxx transaction:
//
//   - scripts whose name is not part of the incoming list are unset from
//     policies, their pending executions are cleared (both from
//     host_script_results and from upcoming_activities) and the scripts are
//     deleted;
//   - every incoming script is upserted, and pending executions that reference
//     that script with a different script content are cleared.
//
// Once the transaction succeeded, the next upcoming activity is activated for
// the hosts that had an already-activated upcoming script execution removed.
//
// It returns the id, team_id and name of every script stored for the team
// after the batch was applied.
func (ds *Datastore) BatchSetScripts(ctx context.Context, tmID *uint, scripts []*fleet.Script) ([]fleet.ScriptResponse, error) {
// names of the scripts already stored for the team that are also part of the
// incoming batch.
const loadExistingScripts = `
SELECT
name
FROM
scripts
WHERE
global_or_team_id = ? AND
name IN (?)
`
// the "All" statements below are used when none of the stored scripts
// must be kept.
const deleteAllScriptsInTeam = `
DELETE FROM
scripts
WHERE
global_or_team_id = ?
`
const unsetAllScriptsFromPolicies = `UPDATE policies SET script_id = NULL WHERE team_id = ?`
// pending executions are the ones without an exit code that are either
// asynchronous requests or synchronous requests created within the last
// N seconds (N is provided as the first placeholder).
const clearAllPendingExecutionsHSR = `DELETE FROM host_script_results WHERE
exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)`
// hosts that have an activated upcoming script activity that is about to be
// deleted; they are collected so their next activity can be activated later.
const loadAffectedHostsAllPendingExecutionsUA = `
SELECT
DISTINCT host_id
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.activity_type = 'script'
AND ua.activated_at IS NOT NULL
AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)`
const clearAllPendingExecutionsUA = `DELETE FROM upcoming_activities
USING
upcoming_activities
INNER JOIN script_upcoming_activities sua
ON upcoming_activities.id = sua.upcoming_activity_id
WHERE
upcoming_activities.activity_type = 'script'
AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)`
// the "NotInList" statements below are the same as the "All" ones, but they
// only target the scripts whose name is not in the list of names to keep.
const unsetScriptsNotInListFromPolicies = `
UPDATE policies SET script_id = NULL
WHERE script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))
`
const deleteScriptsNotInList = `
DELETE FROM
scripts
WHERE
global_or_team_id = ? AND
name NOT IN (?)
`
const clearPendingExecutionsNotInListHSR = `DELETE FROM host_script_results WHERE
exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))`
const loadAffectedHostsPendingExecutionsNotInListUA = `
SELECT
DISTINCT host_id
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.activity_type = 'script'
AND ua.activated_at IS NOT NULL
AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))`
const clearPendingExecutionsNotInListUA = `DELETE FROM upcoming_activities
USING
upcoming_activities
INNER JOIN script_upcoming_activities sua
ON upcoming_activities.id = sua.upcoming_activity_id
WHERE
upcoming_activities.activity_type = 'script'
AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))`
// upsert of an incoming script; id=LAST_INSERT_ID(id) is there so that the
// id of an updated (already existing) row is reported by LastInsertId too,
// which is how the script id is retrieved after executing this statement.
const insertNewOrEditedScript = `
INSERT INTO
scripts (
team_id, global_or_team_id, name, script_content_id
)
VALUES
(?, ?, ?, ?)
ON DUPLICATE KEY UPDATE
script_content_id = VALUES(script_content_id), id=LAST_INSERT_ID(id)
`
// the "WithObsoleteScript" statements target the pending executions of a
// given script that reference a content other than the provided one.
const clearPendingExecutionsWithObsoleteScriptHSR = `DELETE FROM host_script_results WHERE
exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
AND script_id = ? AND script_content_id != ?`
const loadAffectedHostsPendingExecutionsWithObsoleteScriptUA = `
SELECT
DISTINCT host_id
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE
ua.activity_type = 'script'
AND ua.activated_at IS NOT NULL
AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id = ? AND sua.script_content_id != ?`
const clearPendingExecutionsWithObsoleteScriptUA = `DELETE FROM upcoming_activities
USING
upcoming_activities
INNER JOIN script_upcoming_activities sua
ON upcoming_activities.id = sua.upcoming_activity_id
WHERE
upcoming_activities.activity_type = 'script'
AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
AND sua.script_id = ? AND sua.script_content_id != ?`
// loads every script of the team, this is what gets returned to the caller.
const loadInsertedScripts = `SELECT id, team_id, name FROM scripts WHERE global_or_team_id = ?`
// use a team id of 0 if no-team
var globalOrTeamID uint
if tmID != nil {
globalOrTeamID = *tmID
}
// build a list of names for the incoming scripts, will keep the
// existing ones if there's a match and no change
incomingNames := make([]string, len(scripts))
// at the same time, index the incoming scripts keyed by name for ease
// of processing
incomingScripts := make(map[string]*fleet.Script, len(scripts))
for i, p := range scripts {
incomingNames[i] = p.Name
incomingScripts[p.Name] = p
}
var insertedScripts []fleet.ScriptResponse
// hosts for which the next upcoming activity must be activated once the
// transaction is done.
var activateAffectedHosts []uint
if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
var existingScripts []*fleet.Script
if len(incomingNames) > 0 {
// load existing scripts that match the incoming scripts by names
stmt, args, err := sqlx.In(loadExistingScripts, globalOrTeamID, incomingNames)
if err != nil {
return ctxerr.Wrap(ctx, err, "build query to load existing scripts")
}
if err := sqlx.SelectContext(ctx, tx, &existingScripts, stmt, args...); err != nil {
return ctxerr.Wrap(ctx, err, "load existing scripts")
}
}
// figure out if we need to delete any scripts
keepNames := make([]string, 0, len(incomingNames))
for _, p := range existingScripts {
if newS := incomingScripts[p.Name]; newS != nil {
keepNames = append(keepNames, p.Name)
}
}
// the statements (and their arguments) to execute are selected below
// depending on whether some existing scripts must be kept or not.
var (
scriptsStmt string
scriptsArgs []any
policiesStmt string
policiesArgs []any
executionsStmt string
executionsArgs []any
extraExecStmt string
extraExecArgs []any
err error
affectedHostIDs []uint
)
if len(keepNames) > 0 {
// delete the obsolete scripts
scriptsStmt, scriptsArgs, err = sqlx.In(deleteScriptsNotInList, globalOrTeamID, keepNames)
if err != nil {
return ctxerr.Wrap(ctx, err, "build statement to delete obsolete scripts")
}
policiesStmt, policiesArgs, err = sqlx.In(unsetScriptsNotInListFromPolicies, globalOrTeamID, keepNames)
if err != nil {
return ctxerr.Wrap(ctx, err, "build statement to unset obsolete scripts from policies")
}
executionsStmt, executionsArgs, err = sqlx.In(clearPendingExecutionsNotInListHSR, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames)
if err != nil {
return ctxerr.Wrap(ctx, err, "build statement to clear pending script executions from obsolete scripts")
}
// the affected hosts must be loaded before the upcoming activities get
// deleted further down.
loadAffectedStmt, args, err := sqlx.In(loadAffectedHostsPendingExecutionsNotInListUA,
int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames)
if err != nil {
return ctxerr.Wrap(ctx, err, "build query to load affected hosts for upcoming script executions")
}
if err := sqlx.SelectContext(ctx, tx, &affectedHostIDs, loadAffectedStmt, args...); err != nil {
return ctxerr.Wrap(ctx, err, "load affected hosts for upcoming script executions")
}
extraExecStmt, extraExecArgs, err = sqlx.In(clearPendingExecutionsNotInListUA, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames)
if err != nil {
return ctxerr.Wrap(ctx, err, "build statement to clear upcoming pending script executions from obsolete scripts")
}
} else {
// no existing script matches the incoming ones, everything stored for
// the team is obsolete.
scriptsStmt = deleteAllScriptsInTeam
scriptsArgs = []any{globalOrTeamID}
policiesStmt = unsetAllScriptsFromPolicies
policiesArgs = []any{globalOrTeamID}
executionsStmt = clearAllPendingExecutionsHSR
executionsArgs = []any{int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID}
if err := sqlx.SelectContext(ctx, tx, &affectedHostIDs,
loadAffectedHostsAllPendingExecutionsUA, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID); err != nil {
return ctxerr.Wrap(ctx, err, "load affected hosts for upcoming script executions")
}
extraExecStmt = clearAllPendingExecutionsUA
extraExecArgs = []any{int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID}
}
// the scripts are deleted last because the other statements identify the
// obsolete scripts with a sub-select on the scripts table.
if _, err := tx.ExecContext(ctx, policiesStmt, policiesArgs...); err != nil {
return ctxerr.Wrap(ctx, err, "unset obsolete scripts from policies")
}
if _, err := tx.ExecContext(ctx, executionsStmt, executionsArgs...); err != nil {
return ctxerr.Wrap(ctx, err, "clear obsolete script pending executions")
}
if _, err := tx.ExecContext(ctx, extraExecStmt, extraExecArgs...); err != nil {
return ctxerr.Wrap(ctx, err, "clear obsolete upcoming script pending executions")
}
if _, err := tx.ExecContext(ctx, scriptsStmt, scriptsArgs...); err != nil {
return ctxerr.Wrap(ctx, err, "delete obsolete scripts")
}
// plain assignment (not an append) so that the list starts over from the
// hosts collected by this execution of the transaction function.
activateAffectedHosts = affectedHostIDs
// insert the new scripts and the ones that have changed
for _, s := range incomingScripts {
scRes, err := insertScriptContents(ctx, tx, s.ScriptContents)
if err != nil {
return ctxerr.Wrapf(ctx, err, "inserting script contents for script with name %q", s.Name)
}
contentID, _ := scRes.LastInsertId()
insertRes, err := tx.ExecContext(ctx, insertNewOrEditedScript, tmID, globalOrTeamID, s.Name, uint(contentID)) //nolint:gosec // dismiss G115
if err != nil {
return ctxerr.Wrapf(ctx, err, "insert new/edited script with name %q", s.Name)
}
scriptID, _ := insertRes.LastInsertId()
// pending executions of this script that reference another content are
// now obsolete.
if _, err := tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScriptHSR, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil {
return ctxerr.Wrapf(ctx, err, "clear obsolete pending script executions with name %q", s.Name)
}
var affectedHosts []uint
if err := sqlx.SelectContext(ctx, tx, &affectedHosts, loadAffectedHostsPendingExecutionsWithObsoleteScriptUA,
int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil {
return ctxerr.Wrapf(ctx, err, "load affected hosts for upcoming script executions with name %q", s.Name)
}
activateAffectedHosts = append(activateAffectedHosts, affectedHosts...)
if _, err = tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScriptUA, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil {
return ctxerr.Wrapf(ctx, err, "clear obsolete upcoming pending script executions with name %q", s.Name)
}
}
if err := sqlx.SelectContext(ctx, tx, &insertedScripts, loadInsertedScripts, globalOrTeamID); err != nil {
return ctxerr.Wrap(ctx, err, "load inserted scripts")
}
return nil
}); err != nil {
return nil, err
}
// this runs after (and outside of) the transaction above.
if err := ds.activateNextUpcomingActivityForBatchOfHosts(ctx, activateAffectedHosts); err != nil {
return nil, ctxerr.Wrap(ctx, err, "activate next upcoming activity for batch of hosts")
}
return insertedScripts, nil
}
// GetHostLockWipeStatus loads the lock, unlock and wipe references recorded
// in host_mdm_actions for the host and resolves them into a
// fleet.HostLockWipeStatus.
//
// When no row exists for the host, the zero-value status (with only the
// host's fleet platform set) is returned without error. Otherwise the
// references are resolved depending on the fleet platform:
//
//   - darwin, ios, ipados: lock and wipe references are Apple MDM commands,
//     the unlock reference is a timestamp and the unlock PIN is copied over.
//   - windows, linux: lock and unlock references are script executions; the
//     wipe reference is a Windows MDM command on windows and a script
//     execution on linux.
//
// Any other platform gets a status with no reference resolved.
func (ds *Datastore) GetHostLockWipeStatus(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
	const stmt = `
SELECT
lock_ref,
wipe_ref,
unlock_ref,
unlock_pin,
fleet_platform
FROM
host_mdm_actions
WHERE
host_id = ?
`
	var mdmActions struct {
		LockRef       *string `db:"lock_ref"`
		WipeRef       *string `db:"wipe_ref"`
		UnlockRef     *string `db:"unlock_ref"`
		UnlockPIN     *string `db:"unlock_pin"`
		FleetPlatform string  `db:"fleet_platform"`
	}

	fleetPlatform := host.FleetPlatform()
	status := &fleet.HostLockWipeStatus{
		HostFleetPlatform: fleetPlatform,
	}

	if err := sqlx.GetContext(ctx, ds.reader(ctx), &mdmActions, stmt, host.ID); err != nil {
		// errors.Is (instead of ==) so that the sentinel is still recognized
		// should it ever reach this point wrapped in another error.
		if errors.Is(err, sql.ErrNoRows) {
			// do not return a Not Found error, return the zero-value status, which
			// will report the correct states.
			return status, nil
		}
		return nil, ctxerr.Wrap(ctx, err, "get host lock/wipe status")
	}

	// if we have a fleet platform stored in host_mdm_actions, use it instead of
	// the host.FleetPlatform() because the platform can be overwritten with an
	// unknown OS name when a Wipe gets executed.
	if mdmActions.FleetPlatform != "" {
		fleetPlatform = mdmActions.FleetPlatform
		status.HostFleetPlatform = fleetPlatform
	}

	switch fleetPlatform {
	case "darwin", "ios", "ipados":
		if mdmActions.UnlockPIN != nil {
			status.UnlockPIN = *mdmActions.UnlockPIN
		}

		if mdmActions.UnlockRef != nil {
			var err error
			status.UnlockRequestedAt, err = time.Parse(time.DateTime, *mdmActions.UnlockRef)
			if err != nil {
				// if the format is unexpected but there's something in UnlockRef, just
				// replace it with the current timestamp, it should still indicate that
				// an unlock was requested (e.g. in case someone plays with the data
				// directly in the DB and messes up the format).
				status.UnlockRequestedAt = time.Now().UTC()
			}
		}

		if mdmActions.LockRef != nil {
			// the lock reference is an MDM command
			cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.LockRef, host.UUID)
			if err != nil {
				return nil, ctxerr.Wrap(ctx, err, "get lock reference")
			}
			status.LockMDMCommand = cmd
			status.LockMDMCommandResult = cmdRes
		}

		if mdmActions.WipeRef != nil {
			// the wipe reference is an MDM command
			cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.WipeRef, host.UUID)
			if err != nil {
				return nil, ctxerr.Wrap(ctx, err, "get wipe reference")
			}
			status.WipeMDMCommand = cmd
			status.WipeMDMCommandResult = cmdRes
		}

	case "windows", "linux":
		// lock and unlock references are scripts
		if mdmActions.LockRef != nil {
			hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.LockRef, scriptExecutionSearchOpts{IncludeCanceled: true})
			if err != nil {
				return nil, ctxerr.Wrap(ctx, err, "get lock reference script result")
			}
			status.LockScript = hsr
		}

		if mdmActions.UnlockRef != nil {
			hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.UnlockRef, scriptExecutionSearchOpts{IncludeCanceled: true})
			if err != nil {
				return nil, ctxerr.Wrap(ctx, err, "get unlock reference script result")
			}
			status.UnlockScript = hsr
		}

		// wipe is an MDM command on Windows, a script on Linux
		if mdmActions.WipeRef != nil {
			if fleetPlatform == "windows" {
				cmd, cmdRes, err := ds.getHostMDMWindowsCommand(ctx, *mdmActions.WipeRef, host.UUID)
				if err != nil {
					return nil, ctxerr.Wrap(ctx, err, "get wipe reference")
				}
				status.WipeMDMCommand = cmd
				status.WipeMDMCommandResult = cmdRes
			} else {
				hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.WipeRef, scriptExecutionSearchOpts{IncludeCanceled: true})
				if err != nil {
					return nil, ctxerr.Wrap(ctx, err, "get wipe reference script result")
				}
				status.WipeScript = hsr
			}
		}
	}

	return status, nil
}
// getHostMDMWindowsCommand returns the Windows MDM command identified by
// cmdUUID along with the result reported for the host identified by
// hostUUID. The returned result is nil when there is none for that host.
func (ds *Datastore) getHostMDMWindowsCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) {
	command, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID)
	if err != nil {
		return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command")
	}

	// A missing result (i.e. the command is pending) is not reported with
	// ErrNoRows: the call succeeds and the slice of results is empty.
	allResults, err := ds.GetMDMWindowsCommandResults(ctx, cmdUUID)
	if err != nil {
		return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command result")
	}

	// The results may be for different hosts, pick the first one that belongs
	// to the requested host. Any status is accepted as, for Windows, they all
	// indicate the end of processing of the command (there is no equivalent
	// of "NotNow" or "Idle" as for Apple).
	for _, candidate := range allResults {
		if candidate.HostUUID == hostUUID {
			return command, candidate, nil
		}
	}
	return command, nil, nil
}
// getHostMDMAppleCommand returns the Apple MDM command identified by cmdUUID
// along with the first result reported for the host identified by hostUUID
// that has an Acknowledged, Error or CommandFormatError status. The returned
// result is nil when there is no such result.
func (ds *Datastore) getHostMDMAppleCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) {
	command, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID)
	if err != nil {
		return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command")
	}

	// A missing result (i.e. the command is pending) is not reported with
	// ErrNoRows: the call succeeds and the slice of results is empty.
	allResults, err := ds.GetMDMAppleCommandResults(ctx, cmdUUID)
	if err != nil {
		return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command result")
	}

	// The results may be for different hosts, only the ones of the requested
	// host are considered.
	for _, candidate := range allResults {
		if candidate.HostUUID != hostUUID {
			continue
		}
		switch candidate.Status {
		case fleet.MDMAppleStatusAcknowledged, fleet.MDMAppleStatusError, fleet.MDMAppleStatusCommandFormatError:
			return command, candidate, nil
		}
	}
	return command, nil, nil
}
// LockHostViaScript stores the script contents, creates the script execution
// request and records its execution ID as the host's lock reference in
// host_mdm_actions, all in a single transaction.
func (ds *Datastore) LockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
	// Only lock_ref is touched when the row already exists: at this point the
	// lock is merely requested and pending execution. The host's state should
	// be updated to "locked" only when the script execution is successfully
	// completed, and then any unlock or wipe references should be cleared.
	const upsertLockRef = `
INSERT INTO host_mdm_actions
(
host_id,
lock_ref,
fleet_platform
)
VALUES (?,?,?)
ON DUPLICATE KEY UPDATE
lock_ref = VALUES(lock_ref)
`
	return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		contentsRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
		if err != nil {
			return err
		}
		contentID, _ := contentsRes.LastInsertId()
		request.ScriptContentID = uint(contentID) //nolint:gosec // dismiss G115

		execution, err := ds.newHostScriptExecutionRequest(ctx, tx, request, true)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "lock host via script create execution")
		}

		if _, err := tx.ExecContext(ctx, upsertLockRef, request.HostID, execution.ExecutionID, hostFleetPlatform); err != nil {
			return ctxerr.Wrap(ctx, err, "lock host via script update mdm actions")
		}
		return nil
	})
}
// UnlockHostViaScript stores the script contents, creates the script
// execution request and records its execution ID as the host's unlock
// reference in host_mdm_actions, all in a single transaction.
func (ds *Datastore) UnlockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
	// When the row already exists, only unlock_ref is set and unlock_pin is
	// reset: at this point the unlock is merely requested and pending
	// execution. The host's state should be updated to "unlocked" only when
	// the script execution is successfully completed, and then any lock or
	// wipe references should be cleared.
	const upsertUnlockRef = `
INSERT INTO host_mdm_actions
(
host_id,
unlock_ref,
fleet_platform
)
VALUES (?,?,?)
ON DUPLICATE KEY UPDATE
unlock_ref = VALUES(unlock_ref),
unlock_pin = NULL
`
	return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		contentsRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
		if err != nil {
			return err
		}
		contentID, _ := contentsRes.LastInsertId()
		request.ScriptContentID = uint(contentID) //nolint:gosec // dismiss G115

		execution, err := ds.newHostScriptExecutionRequest(ctx, tx, request, true)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "unlock host via script create execution")
		}

		if _, err := tx.ExecContext(ctx, upsertUnlockRef, request.HostID, execution.ExecutionID, hostFleetPlatform); err != nil {
			return ctxerr.Wrap(ctx, err, "unlock host via script update mdm actions")
		}
		return nil
	})
}
// WipeHostViaScript stores the script contents, creates the script execution
// request and records its execution ID as the host's wipe reference in
// host_mdm_actions, all in a single transaction.
func (ds *Datastore) WipeHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
	// Only wipe_ref is touched when the row already exists: at this point the
	// wipe is merely requested and pending execution, so if the host was
	// locked, it is still locked (the lock_ref info must still be there).
	const upsertWipeRef = `
INSERT INTO host_mdm_actions
(
host_id,
wipe_ref,
fleet_platform
)
VALUES (?,?,?)
ON DUPLICATE KEY UPDATE
wipe_ref = VALUES(wipe_ref)
`
	return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		contentsRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
		if err != nil {
			return err
		}
		contentID, _ := contentsRes.LastInsertId()
		request.ScriptContentID = uint(contentID) //nolint:gosec // dismiss G115

		execution, err := ds.newHostScriptExecutionRequest(ctx, tx, request, true)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "wipe host via script create execution")
		}

		if _, err := tx.ExecContext(ctx, upsertWipeRef, request.HostID, execution.ExecutionID, hostFleetPlatform); err != nil {
			return ctxerr.Wrap(ctx, err, "wipe host via script update mdm actions")
		}
		return nil
	})
}
// UnlockHostManually records in host_mdm_actions that an unlock was requested
// for the host at the provided timestamp. If an unlock reference is already
// set for the host, it is left untouched.
func (ds *Datastore) UnlockHostManually(ctx context.Context, hostID uint, hostFleetPlatform string, ts time.Time) error {
	const recordUnlockRequest = `
INSERT INTO host_mdm_actions
(
host_id,
unlock_ref,
fleet_platform
)
VALUES (?, ?, ?)
ON DUPLICATE KEY UPDATE
-- do not overwrite if a value is already set
unlock_ref = IF(unlock_ref IS NULL, VALUES(unlock_ref), unlock_ref)
`
	// For macOS, the unlock_ref is simply the timestamp of the first unlock
	// request made by the user. It flags the host as pending an unlock, which
	// requires manual intervention (entering a PIN on the device). As the
	// /unlock endpoint can be called multiple times, only the timestamp of the
	// first request is recorded and the host stays "pending unlock" from then
	// on, until the device is actually unlocked with the PIN. The actual
	// unlocking happens when the device sends an Idle MDM request.
	_, err := ds.writer(ctx).ExecContext(ctx, recordUnlockRequest, hostID, ts.Format(time.DateTime), hostFleetPlatform)
	return ctxerr.Wrap(ctx, err, "record manual unlock host request")
}
// buildHostLockWipeStatusUpdateStmt builds the "UPDATE host_mdm_actions ...
// SET ..." part of the statement that records the outcome of a lock, unlock
// or wipe action. The caller is responsible for appending the WHERE clause.
//
//   - refCol is the column of the action that completed: "lock_ref",
//     "unlock_ref" or "wipe_ref". It is inserted as-is in the statement, so it
//     must never come from untrusted input. If succeeded is true and refCol
//     is none of those three values, nothing follows "SET" in the returned
//     statement.
//   - succeeded indicates if the action succeeded. On failure only refCol is
//     cleared, so the host stays in its previous state.
//   - joinPart is an optional JOIN clause. When set, host_mdm_actions is
//     aliased as "hma" and all updated columns are prefixed with that alias.
//   - setUnlockRef only applies to a successful lock: unlock_ref is then set
//     to the current UTC timestamp instead of being cleared.
func buildHostLockWipeStatusUpdateStmt(refCol string, succeeded bool, joinPart string, setUnlockRef bool) string {
	var alias string

	stmt := `UPDATE host_mdm_actions `
	if joinPart != "" {
		stmt += `hma ` + joinPart
		alias = "hma."
	}
	stmt += ` SET `

	if !succeeded {
		// if the action failed, then we clear the reference to that action itself
		// so the host stays in the previous state (it doesn't transition to the
		// new state). Plain concatenation is used so that refCol is never
		// interpreted as part of a format string.
		return stmt + alias + refCol + " = NULL"
	}

	switch refCol {
	case "lock_ref":
		// Note that this must not clear the unlock_pin, because recording the
		// lock request does generate the PIN and store it there to be used by an
		// eventual unlock.
		if !setUnlockRef {
			stmt += fmt.Sprintf("%sunlock_ref = NULL, %[1]swipe_ref = NULL", alias)
		} else {
			// Currently only used for Apple MDM devices.
			// We set the unlock_ref to current time since the device can be
			// unlocked any time after the lock. Apple MDM does not have a concept
			// of unlock pending.
			//
			// The timestamp is explicitly in UTC because the layout carries no
			// time zone and it is read back with time.Parse, which interprets
			// it as UTC.
			stmt += fmt.Sprintf("%sunlock_ref = '%s', %[1]swipe_ref = NULL", alias, time.Now().UTC().Format(time.DateTime))
		}
	case "unlock_ref":
		// a successful unlock clears itself as well as the lock ref, because
		// unlock is the default state so we don't need to keep its unlock_ref
		// around once it's confirmed.
		stmt += fmt.Sprintf("%slock_ref = NULL, %[1]sunlock_ref = NULL, %[1]sunlock_pin = NULL, %[1]swipe_ref = NULL", alias)
	case "wipe_ref":
		stmt += fmt.Sprintf("%slock_ref = NULL, %[1]sunlock_ref = NULL, %[1]sunlock_pin = NULL", alias)
	}
	return stmt
}
// UpdateHostLockWipeStatusFromAppleMDMResult updates the lock/wipe state
// recorded in host_mdm_actions for the host identified by hostUUID, based on
// the result of the Apple MDM command cmdUUID. Only the "EraseDevice" and
// "DeviceLock" request types are handled, it is a no-op for any other type.
func (ds *Datastore) UpdateHostLockWipeStatusFromAppleMDMResult(ctx context.Context, hostUUID, cmdUUID, requestType string, succeeded bool) error {
	// a bit of MDM protocol leaking in the mysql layer, but it's either that or
	// the other way around (MDM protocol would translate to database column)
	refCol, setUnlockRef := "", false
	switch requestType {
	case "DeviceLock":
		refCol, setUnlockRef = "lock_ref", true
	case "EraseDevice":
		refCol = "wipe_ref"
	default:
		return nil
	}
	return updateHostLockWipeStatusFromResultAndHostUUID(ctx, ds.writer(ctx), hostUUID, refCol, cmdUUID, succeeded, setUnlockRef)
}
// updateHostLockWipeStatusFromResultAndHostUUID applies the lock/wipe status
// update built by buildHostLockWipeStatusUpdateStmt to the host_mdm_actions
// row of the host identified by hostUUID, provided that its refCol column
// currently holds cmdUUID.
func updateHostLockWipeStatusFromResultAndHostUUID(
	ctx context.Context, tx sqlx.ExtContext, hostUUID, refCol, cmdUUID string, succeeded bool, setUnlockRef bool,
) error {
	const hostsJoin = `JOIN hosts h ON hma.host_id = h.id`

	query := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, hostsJoin, setUnlockRef) +
		` WHERE h.uuid = ? AND hma.` + refCol + ` = ?`
	_, err := tx.ExecContext(ctx, query, hostUUID, cmdUUID)
	return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result via host uuid")
}
// updateHostLockWipeStatusFromResult applies the lock/wipe status update
// built by buildHostLockWipeStatusUpdateStmt (without join and without
// setting the unlock reference) to the host_mdm_actions row of the host
// identified by hostID.
func updateHostLockWipeStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, refCol string, succeeded bool) error {
	query := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, "", false) + ` WHERE host_id = ?`
	_, err := tx.ExecContext(ctx, query, hostID)
	return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result")
}
// updateUninstallStatusFromResult stores exitCode as the uninstall script
// exit code of the host software install identified by executionID and
// hostID.
//
// NOTE: there is no need to call activateNextUpcomingActivity here as this
// function is called from SetHostScriptExecutionResult, which will call it
// before completing.
func (ds *Datastore) updateUninstallStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, executionID string, exitCode int) error {
	const setExitCode = `
UPDATE host_software_installs SET uninstall_script_exit_code = ? WHERE execution_id = ? AND host_id = ?
`
	_, err := tx.ExecContext(ctx, setExitCode, exitCode, executionID, hostID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "update uninstall status from result")
	}
	return nil
}
// CleanupUnusedScriptContents deletes the script_contents rows that are not
// referenced by any of: host_script_results, scripts, software_installers
// (install, post-install and uninstall scripts), setup_experience_scripts
// and script_upcoming_activities.
func (ds *Datastore) CleanupUnusedScriptContents(ctx context.Context) error {
	const deleteUnreferenced = `
DELETE FROM
script_contents
WHERE
NOT EXISTS (
SELECT 1 FROM host_script_results WHERE script_content_id = script_contents.id)
AND NOT EXISTS (
SELECT 1 FROM scripts WHERE script_content_id = script_contents.id)
AND NOT EXISTS (
SELECT 1 FROM software_installers si
WHERE script_contents.id IN (si.install_script_content_id, si.post_install_script_content_id, si.uninstall_script_content_id)
)
AND NOT EXISTS (
SELECT 1 FROM setup_experience_scripts WHERE script_content_id = script_contents.id
)
AND NOT EXISTS (
SELECT 1 FROM script_upcoming_activities WHERE script_content_id = script_contents.id
)
`
	if _, err := ds.writer(ctx).ExecContext(ctx, deleteUnreferenced); err != nil {
		return ctxerr.Wrap(ctx, err, "cleaning up unused script contents")
	}
	return nil
}
// getOrGenerateScriptContentsID returns the ID of the script_contents row
// whose MD5 checksum matches the checksum of contents. The lookup and the
// fallback insert statements are both handed to ds.optimisticGetOrInsert.
func (ds *Datastore) getOrGenerateScriptContentsID(ctx context.Context, contents string) (uint, error) {
	checksum := md5ChecksumScriptContent(contents)

	lookup := &parameterizedStmt{
		Statement: `SELECT id FROM script_contents WHERE md5_checksum = UNHEX(?)`,
		Args:      []any{checksum},
	}
	insert := &parameterizedStmt{
		Statement: `INSERT INTO script_contents (md5_checksum, contents) VALUES (UNHEX(?), ?)`,
		Args:      []any{checksum, contents},
	}

	id, err := ds.optimisticGetOrInsert(ctx, lookup, insert)
	if err != nil {
		return 0, err
	}
	return id, nil
}
// teamIDEq reports whether both team IDs identify the same team: either both
// are nil ("no team"), or both are set and point to equal values.
func teamIDEq(teamID1, teamID2 *uint) bool {
	if teamID1 == nil || teamID2 == nil {
		// equal only if the other one is nil too
		return teamID1 == nil && teamID2 == nil
	}
	return *teamID1 == *teamID2
}
// batchExecuteScript queues the execution of the script identified by
// scriptID on every host of hostIDs and records the outcome under the batch
// identified by batchExecID.
//
// A script execution is only queued for hosts that can run it. The other
// hosts get a batch_activity_host_results row with an error instead:
//
//   - fleet.BatchExecuteInvalidHost when the host could not be loaded;
//   - fleet.BatchExecuteIncompatibleFleetd when the host has no orbit node key
//     or has scripts disabled;
//   - fleet.BatchExecuteIncompatiblePlatform when the script cannot run on the
//     host's platform.
//
// The batch_activities row of the batch is created, or flagged as started if
// it already exists, with num_targeted set to len(hostIDs) on creation. All
// the writes happen in a single transaction.
//
// It returns an invalid argument error for "script_id" if the script cannot
// be loaded.
func (ds *Datastore) batchExecuteScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint, batchExecID string) error {
	script, err := ds.Script(ctx, scriptID)
	if err != nil {
		return fleet.NewInvalidArgumentError("script_id", err.Error())
	}

	// Marker set as platform of the placeholder hosts standing in for the host
	// IDs that could not be loaded.
	invalidHostIDPlatform := "batch-invalid-hostid"

	// We need full host info to check if hosts are able to run scripts, see svc.RunHostScript
	fullHosts := make([]*fleet.Host, 0, len(hostIDs))

	// Check that all hosts exist before attempting to process them
	for _, hostID := range hostIDs {
		host, err := ds.Host(ctx, hostID)
		if err != nil {
			fullHosts = append(fullHosts, &fleet.Host{
				ID:       hostID,
				Platform: invalidHostIDPlatform,
			})
			continue
		}
		fullHosts = append(fullHosts, host)
	}

	if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		// The execution results to be stored in the database. This must be
		// local to the transaction function: if the transaction is retried,
		// the entries of the previous (rolled back) attempt must not be kept.
		executions := make([]fleet.BatchExecutionHost, 0, len(fullHosts))

		for _, host := range fullHosts {
			// Host doesn't exist anymore
			if host.Platform == invalidHostIDPlatform {
				executions = append(executions, fleet.BatchExecutionHost{
					HostID: host.ID,
					Error:  &fleet.BatchExecuteInvalidHost,
				})
				continue
			}

			// Non-orbit-enrolled host (iOS, android)
			noNodeKey := host.OrbitNodeKey == nil || *host.OrbitNodeKey == ""
			// Scripts disabled on host
			scriptsDisabled := host.ScriptsEnabled != nil && !*host.ScriptsEnabled
			if noNodeKey || scriptsDisabled {
				executions = append(executions, fleet.BatchExecutionHost{
					HostID: host.ID,
					Error:  &fleet.BatchExecuteIncompatibleFleetd,
				})
				continue
			}

			if !fleet.ValidateScriptPlatform(script.Name, host.Platform) {
				executions = append(executions, fleet.BatchExecutionHost{
					HostID: host.ID,
					Error:  &fleet.BatchExecuteIncompatiblePlatform,
				})
				continue
			}

			executionID, _, err := ds.insertNewHostScriptExecution(ctx, tx, &fleet.HostScriptRequestPayload{
				HostID:          host.ID,
				UserID:          userID,
				ScriptID:        &script.ID,
				ScriptContentID: script.ScriptContentID,
			}, false)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "queueing script for bulk execution")
			}
			executions = append(executions, fleet.BatchExecutionHost{
				HostID:      host.ID,
				ExecutionID: &executionID,
			})
		}

		// If the batch activity already exists, it is only marked as started.
		_, err := tx.ExecContext(
			ctx,
			`INSERT INTO batch_activities (execution_id, script_id, status, activity_type, num_targeted, started_at) VALUES (?, ?, ?, ?, ?, NOW())
ON DUPLICATE KEY UPDATE status = VALUES(status), started_at = VALUES(started_at)`,
			batchExecID,
			script.ID,
			fleet.ScheduledBatchExecutionStarted,
			fleet.BatchExecutionActivityScript,
			len(hostIDs),
		)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "failed to insert new batch execution")
		}

		args := make([]map[string]any, 0, len(executions))
		for _, execHost := range executions {
			args = append(args, map[string]any{
				"batch_id":          batchExecID,
				"host_id":           execHost.HostID,
				"host_execution_id": execHost.ExecutionID,
				"error":             execHost.Error,
			})
		}

		// Host results that already exist for the batch are updated with the
		// execution ID and error of this run.
		insertStmt := `
INSERT INTO batch_activity_host_results (
batch_execution_id,
host_id,
host_execution_id,
error
) VALUES (
:batch_id,
:host_id,
:host_execution_id,
:error
) ON DUPLICATE KEY UPDATE host_execution_id = VALUES(host_execution_id), error = VALUES(error)`
		if _, err := sqlx.NamedExecContext(ctx, tx, insertStmt, args); err != nil {
			return ctxerr.Wrap(ctx, err, "associating script executions with batch job")
		}

		return nil
	}); err != nil {
		return fmt.Errorf("creating bulk execution order: %w", err)
	}

	return nil
}
// BatchExecuteScript immediately queues the script identified by scriptID
// for execution on the provided hosts and returns the generated ID of the
// batch execution.
//
// It fails without queueing anything if the script or one of the hosts
// cannot be loaded, or if a host is not on the same team as the script.
func (ds *Datastore) BatchExecuteScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint) (string, error) {
	batchExecID := uuid.New().String()

	script, err := ds.Script(ctx, scriptID)
	if err != nil {
		return "", fleet.NewInvalidArgumentError("script_id", err.Error())
	}

	// validate every target before creating anything
	for _, targetID := range hostIDs {
		target, err := ds.HostLite(ctx, targetID)
		if err != nil {
			return "", fmt.Errorf("unable to load host information for %d: %w", targetID, err)
		}
		if sameTeam := teamIDEq(target.TeamID, script.TeamID); !sameTeam {
			return "", ctxerr.Errorf(ctx, "all hosts must be on the same team as the script")
		}
	}

	err = ds.batchExecuteScript(ctx, userID, scriptID, hostIDs, batchExecID)
	if err != nil {
		return "", ctxerr.Wrap(ctx, err, "immediate batch execution")
	}
	return batchExecID, nil
}
// BatchScheduleScript schedules the execution of the script identified by
// scriptID on the provided hosts, no sooner than notBefore, and returns the
// generated ID of the batch execution.
//
// It creates a queued job (whose arguments carry the batch execution ID), a
// batch_activities row in the "scheduled" status referencing that job, and
// one batch_activity_host_results row per targeted host.
func (ds *Datastore) BatchScheduleScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint, notBefore time.Time) (string, error) {
	batchExecID := uuid.New().String()

	const batchActivitiesStmt = `INSERT INTO batch_activities (execution_id, job_id, script_id, user_id, status, activity_type, num_targeted) VALUES (?, ?, ?, ?, ?, ?, ?)`
	const batchHostsStmt = `INSERT INTO batch_activity_host_results (batch_execution_id, host_id) VALUES (:exec_id, :host_id)`

	argBytes, err := json.Marshal(fleet.BatchActivityScriptJobArgs{
		ExecutionID: batchExecID,
	})
	if err != nil {
		return "", ctxerr.Wrap(ctx, err, "encoding job args")
	}

	if err := ds.withTx(ctx, func(tx sqlx.ExtContext) error {
		// NOTE(review): NewJob is not given tx, so the job is presumably not
		// created as part of this transaction - confirm this is intended.
		job, err := ds.NewJob(ctx, &fleet.Job{
			Name:      fleet.BatchActivityScriptsJobName,
			Args:      (*json.RawMessage)(&argBytes),
			State:     fleet.JobStateQueued,
			NotBefore: notBefore.UTC(),
		})
		if err != nil {
			return ctxerr.Wrap(ctx, err, "creating new job")
		}

		_, err = tx.ExecContext(
			ctx,
			batchActivitiesStmt,
			batchExecID,
			job.ID,
			scriptID,
			userID,
			fleet.ScheduledBatchExecutionScheduled,
			fleet.BatchExecutionActivityScript,
			len(hostIDs),
		)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "inserting new batch activity")
		}

		// one row of named arguments per targeted host
		args := make([]map[string]any, 0, len(hostIDs))
		for _, hostID := range hostIDs {
			args = append(args, map[string]any{
				"exec_id": batchExecID,
				"host_id": hostID,
			})
		}

		if _, err := sqlx.NamedExecContext(ctx, tx, batchHostsStmt, args); err != nil {
			return ctxerr.Wrap(ctx, err, "inserting batch host results")
		}

		return nil
	}); err != nil {
		return "", ctxerr.Wrap(ctx, err, "creating scheduled script execution")
	}

	return batchExecID, nil
}
// CancelBatchScript cancels the batch script execution identified by
// executionID. It is a no-op if the batch activity is already finished.
//
// If the batch activity references a job, that job is marked as successful.
// Then, if the batch has started, the host executions selected by stmt are
// canceled, the batch is flagged as canceled and
// ds.markActivitiesAsCompleted is called. Otherwise (batch not started), the
// batch is marked as finished and canceled, and all of its host results are
// counted as canceled.
func (ds *Datastore) CancelBatchScript(ctx context.Context, executionID string) error {
// host executions of the batch that have no batch-level error and whose
// script result is neither canceled nor completed (no exit code).
// NOTE(review): because the WHERE clause filters on hsr columns, batch rows
// without a matching host_script_results row are excluded, so the LEFT JOIN
// effectively behaves like an inner join - confirm this is intended.
stmt := `
SELECT
bahr.host_execution_id,
bahr.host_id
FROM
batch_activity_host_results bahr
LEFT JOIN
host_script_results hsr ON bahr.host_execution_id = hsr.execution_id -- I think?
WHERE
bahr.batch_execution_id = ?
AND
hsr.canceled = 0
AND
hsr.exit_code IS NULL
AND
bahr.error IS NULL`
// used when the batch did not start yet: marks it as finished and canceled,
// and counts every host result of the batch as canceled.
stmtSetCanceled := `
UPDATE
batch_activities ba
SET
finished_at = NOW(),
status = 'finished',
canceled = 1,
num_canceled = (SELECT COUNT(*) FROM batch_activity_host_results WHERE batch_execution_id = ba.execution_id)
WHERE
ba.execution_id = ?`
// used when the batch already started: only sets the canceled flag.
stmtCanceled := `
UPDATE
batch_activities
SET
canceled = 1
WHERE
execution_id = ?`
activity, err := ds.GetBatchActivity(ctx, executionID)
if err != nil {
return ctxerr.Wrap(ctx, err, "getting batch activity")
}
// nothing to cancel if the batch is already finished
if activity.Status == fleet.ScheduledBatchExecutionFinished {
return nil
}
if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
// If job worker exists, mark it as complete to stop it from running
if jobID := activity.JobID; jobID != nil {
// NOTE(review): GetJob is called with ctx only (not tx), so the job is
// presumably read outside of this transaction - confirm.
job, err := ds.GetJob(ctx, *jobID)
if err != nil {
return ctxerr.Wrap(ctx, err, "failed to find job associated with batch activity")
}
job.State = fleet.JobStateSuccess
if _, err := ds.updateJob(ctx, tx, *jobID, job); err != nil {
return ctxerr.Wrap(ctx, err, "updating batch activity job")
}
}
if activity.Status == fleet.ScheduledBatchExecutionStarted {
// If the batch activity has started, we need to cancel anything in progress or queued
toCancel := []struct {
HostExecutionID string `db:"host_execution_id"`
HostID uint `db:"host_id"`
}{}
if err := sqlx.SelectContext(ctx, tx, &toCancel, stmt, executionID); err != nil {
return ctxerr.Wrap(ctx, err, "selecting hosts to cancel")
}
for _, host := range toCancel {
if _, err := ds.cancelHostUpcomingActivity(ctx, tx, host.HostID, host.HostExecutionID); err != nil {
return ctxerr.Wrap(ctx, err, "canceling upcoming activity")
}
}
if _, err := tx.ExecContext(ctx, stmtCanceled, executionID); err != nil {
return ctxerr.Wrap(ctx, err, "setting canceled column")
}
// the per-status counts and the completion of the batch are handled by
// this call (see the error message below).
if err := ds.markActivitiesAsCompleted(ctx, tx); err != nil {
return ctxerr.Wrap(ctx, err, "marking job as complete and summarizing counts")
}
} else {
// The batch activity is scheduled, but not started
if _, err := tx.ExecContext(ctx, stmtSetCanceled, executionID); err != nil {
return ctxerr.Wrap(ctx, err, "setting canceled host count")
}
}
return nil
}); err != nil {
return ctxerr.Wrap(ctx, err, "cancel batch script db transaction")
}
return nil
}
// GetBatchActivity loads the batch activity identified by executionID from
// the reader, together with the name of its script.
//
// The scripts table is LEFT JOINed, so the activity is still returned when no
// script row matches (the script name is then NULL). Any error reported by the
// query is returned wrapped.
func (ds *Datastore) GetBatchActivity(ctx context.Context, executionID string) (*fleet.BatchActivity, error) {
	const query = `
SELECT
ba.id,
ba.script_id,
s.name as script_name,
ba.execution_id,
ba.user_id,
ba.job_id,
ba.status,
ba.activity_type,
ba.num_targeted,
ba.num_pending,
ba.num_ran,
ba.num_errored,
ba.num_incompatible,
ba.num_canceled,
ba.created_at,
ba.updated_at,
ba.started_at,
ba.finished_at,
ba.canceled
FROM
batch_activities ba
LEFT JOIN
scripts s ON s.id = ba.script_id
WHERE
execution_id = ?`

	var activity fleet.BatchActivity
	err := sqlx.GetContext(ctx, ds.reader(ctx), &activity, query, executionID)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "selecting batch activity")
	}

	return &activity, nil
}
// GetBatchActivityHostResults returns the per-host rows recorded for the
// batch execution identified by executionID, read from the reader.
//
// The returned slice is never nil; it is empty when no row matches.
func (ds *Datastore) GetBatchActivityHostResults(ctx context.Context, executionID string) ([]*fleet.BatchActivityHostResult, error) {
	const query = `
SELECT
id,
batch_execution_id,
host_id,
host_execution_id,
error
FROM
batch_activity_host_results
WHERE
batch_execution_id = ?`

	// Start from an empty, non-nil slice so callers always get a slice back.
	hostResults := make([]*fleet.BatchActivityHostResult, 0)
	err := sqlx.SelectContext(ctx, ds.reader(ctx), &hostResults, query, executionID)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "selecting batch activity host results")
	}

	return hostResults, nil
}
// RunScheduledBatchActivity starts the scheduled batch activity identified by
// executionID by running its script against every host recorded for the
// batch.
//
// An error is returned when the activity is not in the scheduled status, was
// canceled, or has no script ID, and when any of the lookups or the batch
// execution itself fails.
func (ds *Datastore) RunScheduledBatchActivity(ctx context.Context, executionID string) error {
	activity, err := ds.GetBatchActivity(ctx, executionID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "getting batch activity")
	}

	// Preconditions, checked in order: status first, then cancellation, then
	// the presence of a script.
	switch {
	case activity.Status != fleet.ScheduledBatchExecutionScheduled:
		return ctxerr.New(ctx, "batch job has already been started")
	case activity.Canceled:
		return ctxerr.New(ctx, "batch job was canceled")
	case activity.ScriptID == nil:
		return ctxerr.New(ctx, "no script ID present in batch activity")
	}

	script, err := ds.Script(ctx, *activity.ScriptID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "could not get script")
	}

	hostResults, err := ds.GetBatchActivityHostResults(ctx, executionID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "getting batch activity host results")
	}

	// Collect the targeted host IDs from the batch's host result rows.
	hostIDs := make([]uint, 0, len(hostResults))
	for i := range hostResults {
		hostIDs = append(hostIDs, hostResults[i].HostID)
	}

	err = ds.batchExecuteScript(ctx, activity.UserID, script.ID, hostIDs, activity.BatchExecutionID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "scheduled batch script execution")
	}

	return nil
}
// BatchExecuteSummary returns the host result counts, computed live from the
// host result rows, and the script details of the batch execution identified
// by executionID.
//
// Deprecated: will be removed in favor of ListBatchScriptExecutions when the
// batch script details page is ready.
func (ds *Datastore) BatchExecuteSummary(ctx context.Context, executionID string) (*fleet.BatchActivity, error) {
	const (
		hostCountsStmt = `
SELECT
COUNT(*) as num_targeted,
COUNT(bsehr.error) as num_did_not_run,
COUNT(CASE WHEN hsr.exit_code = 0 THEN 1 END) as num_succeeded,
COUNT(CASE WHEN hsr.exit_code <> 0 THEN 1 END) as num_failed,
COUNT(CASE WHEN hsr.canceled = 1 AND hsr.exit_code IS NULL THEN 1 END) as num_cancelled
FROM
batch_activity_host_results bsehr
LEFT JOIN
host_script_results hsr
ON bsehr.host_execution_id = hsr.execution_id
WHERE
bsehr.batch_execution_id = ?`

		scriptDetailsStmt = `
SELECT
script_id,
s.name as script_name,
s.team_id as team_id,
bse.created_at as created_at
FROM
batch_activities bse
JOIN
scripts s
ON bse.script_id = s.id
WHERE
bse.execution_id = ?`
	)

	// Raw per-outcome host counts as produced by hostCountsStmt.
	var counts struct {
		NumTargeted  uint `db:"num_targeted"`
		NumDidNotRun uint `db:"num_did_not_run"`
		NumSucceeded uint `db:"num_succeeded"`
		NumFailed    uint `db:"num_failed"`
		NumCancelled uint `db:"num_cancelled"`
	}
	if err := sqlx.GetContext(ctx, ds.reader(ctx), &counts, hostCountsStmt, executionID); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary")
	}

	// Hosts that already reached an outcome; everything else is pending.
	settled := counts.NumSucceeded + counts.NumFailed + counts.NumDidNotRun + counts.NumCancelled

	summary := fleet.BatchActivity{
		NumTargeted: &counts.NumTargeted,
		// Hosts that ran the script successfully (exit code 0).
		NumRan: &counts.NumSucceeded,
		// Hosts that errored out: a non-zero exit code, or the script did not
		// run at all.
		NumErrored: ptr.Uint(counts.NumFailed + counts.NumDidNotRun),
		// Hosts whose run was canceled without producing an exit code.
		NumCanceled: &counts.NumCancelled,
		// Hosts still waiting for an outcome.
		NumPending: ptr.Uint(counts.NumTargeted - settled),
	}

	// Fill in the script details on top of the counts.
	if err := sqlx.GetContext(ctx, ds.reader(ctx), &summary, scriptDetailsStmt, executionID); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "selecting script information for bulk execution summary")
	}

	// A NULL team ID is reported as team 0.
	if summary.TeamID == nil {
		summary.TeamID = ptr.Uint(0)
	}

	return &summary, nil
}
// ListBatchScriptExecutions returns a page of batch script activities
// matching filter, together with their host result counts.
//
// Filtering:
//   - when filter.ExecutionID is set and non-empty, it is the only filter
//     applied;
//   - otherwise filter.Status and filter.TeamID are applied when set;
//   - when nothing is set, every batch activity that has a script matches.
//
// Counts for finished batches are read from the values stored on the batch
// activity; counts for unfinished batches are computed live from the host
// result rows.
//
// Results are ordered by created_at and then id, both descending. When
// filtering by a known status, a status-specific column is prepended to that
// ordering: not_before ascending for scheduled, started_at descending for
// started, finished_at descending for finished.
//
// Pagination uses filter.Limit and filter.Offset, defaulting to 10 and 0.
func (ds *Datastore) ListBatchScriptExecutions(ctx context.Context, filter fleet.BatchExecutionStatusFilter) ([]fleet.BatchActivity, error) {
	// Placeholders, in order: WHERE fragment for finished batches, the same
	// fragment for unfinished batches, ORDER BY list, LIMIT, OFFSET.
	stmtExecutions := `
SELECT *
FROM (
-- If batch is finished, get the cached host result counts
SELECT
COALESCE(ba.num_targeted, 0) AS num_targeted,
COALESCE(ba.num_incompatible, 0) AS num_incompatible,
COALESCE(ba.num_ran, 0) AS num_ran,
COALESCE(ba.num_errored, 0) AS num_errored,
COALESCE(ba.num_canceled, 0) AS num_canceled,
COALESCE(ba.num_pending, 0) AS num_pending,
ba.execution_id,
ba.script_id,
ba.status,
ba.canceled,
ba.finished_at,
ba.started_at,
s.name AS script_name,
s.global_or_team_id AS team_id,
ba.created_at AS created_at,
j.not_before AS not_before,
ba.id AS id
FROM batch_activities ba
JOIN scripts s ON ba.script_id = s.id
LEFT JOIN jobs j ON j.id = ba.job_id
WHERE ( %s ) AND ba.status = 'finished'
UNION ALL
-- If batch is not finished, calculate the host result counts live.
SELECT
COUNT(bahr.host_id) AS num_targeted,
COUNT(bahr.error) AS num_incompatible,
COUNT(IF(hsr.exit_code = 0, 1, NULL)) AS num_ran,
COUNT(IF(hsr.exit_code <> 0, 1, NULL)) AS num_errored,
COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba.canceled = 1), 1, NULL)) AS num_cancelled,
(
COUNT(bahr.host_id)
- COUNT(bahr.error)
- COUNT(IF(hsr.exit_code = 0, 1, NULL))
- COUNT(IF(hsr.exit_code <> 0, 1, NULL))
- COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba.canceled = 1), 1, NULL))
) AS num_pending,
ba.execution_id,
ba.script_id,
ba.status,
ba.canceled,
ba.finished_at,
ba.started_at,
s.name AS script_name,
s.global_or_team_id AS team_id,
ba.created_at AS created_at,
j.not_before AS not_before,
ba.id AS id
FROM batch_activities ba
LEFT JOIN batch_activity_host_results bahr
ON ba.execution_id = bahr.batch_execution_id
LEFT JOIN host_script_results hsr
ON bahr.host_execution_id = hsr.execution_id
JOIN scripts s
ON ba.script_id = s.id
LEFT JOIN jobs j
ON j.id = ba.job_id
WHERE ( %s ) AND ba.status <> 'finished'
GROUP BY ba.id
) AS u
ORDER BY
%s
LIMIT %d OFFSET %d
`

	limit := 10
	offset := 0
	args := []any{}
	orderBy := []string{"u.created_at DESC", "u.id DESC"}
	whereClauses := make([]string, 0, 2)

	// If an execution ID is provided, use it to filter the results.
	if filter.ExecutionID != nil && *filter.ExecutionID != "" {
		whereClauses = append(whereClauses, "ba.execution_id = ?")
		args = append(args, *filter.ExecutionID)
	} else {
		// Otherwise filter by status and/or team ID.
		if filter.Status != nil && *filter.Status != "" {
			whereClauses = append(whereClauses, "ba.status = ?")
			args = append(args, *filter.Status)

			switch *filter.Status {
			case string(fleet.ScheduledBatchExecutionScheduled):
				orderBy = append([]string{"u.not_before ASC"}, orderBy...)
			case string(fleet.ScheduledBatchExecutionStarted):
				orderBy = append([]string{"u.started_at DESC"}, orderBy...)
			case string(fleet.ScheduledBatchExecutionFinished):
				orderBy = append([]string{"u.finished_at DESC"}, orderBy...)
			default:
				// no additional ordering
			}
		}
		if filter.TeamID != nil {
			whereClauses = append(whereClauses, "s.global_or_team_id = ?")
			args = append(args, *filter.TeamID)
		}
	}

	// Double up the args to use them in both WHERE clauses.
	args = append(args, args...)

	// Use pagination parameters if provided.
	if filter.Limit != nil {
		limit = int(*filter.Limit) //nolint:gosec // dismiss G115
	}
	if filter.Offset != nil {
		offset = int(*filter.Offset) //nolint:gosec // dismiss G115
	}

	where := strings.Join(whereClauses, " AND ")
	if where == "" {
		// No filter was requested. An empty fragment would render as
		// "WHERE (  ) AND ...", which is invalid SQL, so use a predicate that
		// matches every row instead.
		where = "1 = 1"
	}

	stmtExecutions = fmt.Sprintf(stmtExecutions, where, where, strings.Join(orderBy, ", "), limit, offset)

	var summary []fleet.BatchActivity
	if err := sqlx.SelectContext(ctx, ds.reader(ctx), &summary, stmtExecutions, args...); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary")
	}

	return summary, nil
}
// CountBatchScriptExecutions returns the number of batch script activities
// matching filter.
//
// Only filter.Status and filter.TeamID are taken into account; the execution
// ID and the pagination fields of the filter are not read. When neither is
// set, every batch activity that has a script is counted.
func (ds *Datastore) CountBatchScriptExecutions(ctx context.Context, filter fleet.BatchExecutionStatusFilter) (int64, error) {
	stmtExecutions := `
SELECT
COUNT(*)
FROM
batch_activities ba
JOIN
scripts s
ON ba.script_id = s.id
WHERE
%s
`

	args := []any{}
	whereClauses := make([]string, 0, 2)

	if filter.Status != nil && *filter.Status != "" {
		whereClauses = append(whereClauses, "ba.status = ?")
		args = append(args, *filter.Status)
	}
	if filter.TeamID != nil {
		whereClauses = append(whereClauses, "s.global_or_team_id = ?")
		args = append(args, *filter.TeamID)
	}

	where := strings.Join(whereClauses, " AND ")
	if where == "" {
		// No filter was requested. A bare "WHERE" with no condition is
		// invalid SQL, so use a predicate that matches every row instead.
		where = "1 = 1"
	}

	stmtExecutions = fmt.Sprintf(stmtExecutions, where)

	var count int64
	if err := sqlx.GetContext(ctx, ds.reader(ctx), &count, stmtExecutions, args...); err != nil {
		return 0, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary")
	}

	return count, nil
}
// markActivitiesAsCompleted finishes every started batch activity whose hosts
// have all reached an outcome, using the given transaction or connection.
//
// A batch qualifies when the sum of its incompatible, ran, errored and
// canceled host counts is at least its number of targeted hosts. Qualifying
// batches get status 'finished', finished_at set to NOW(), their counts
// stored on the row and num_pending reset to 0.
func (ds *Datastore) markActivitiesAsCompleted(ctx context.Context, tx sqlx.ExtContext) error {
	const finishStartedBatches = `
UPDATE batch_activities AS ba
JOIN (
SELECT
ba2.id AS batch_id,
COUNT(bahr.host_id) AS num_targeted,
COUNT(bahr.error) AS num_incompatible,
COUNT(IF(hsr.exit_code = 0, 1, NULL)) AS num_ran,
COUNT(IF(hsr.exit_code <> 0, 1, NULL)) AS num_errored,
COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba2.canceled = 1), 1, NULL)) AS num_canceled
FROM batch_activities AS ba2
LEFT JOIN batch_activity_host_results AS bahr
ON ba2.execution_id = bahr.batch_execution_id
LEFT JOIN host_script_results AS hsr
ON bahr.host_execution_id = hsr.execution_id
WHERE ba2.status = 'started'
GROUP BY ba2.id
HAVING (num_incompatible + num_ran + num_errored + num_canceled) >= num_targeted
) AS agg
ON agg.batch_id = ba.id
SET
ba.status = 'finished',
ba.finished_at = NOW(),
ba.num_targeted = agg.num_targeted,
ba.num_incompatible = agg.num_incompatible,
ba.num_ran = agg.num_ran,
ba.num_errored = agg.num_errored,
ba.num_canceled = agg.num_canceled,
ba.num_pending = 0
WHERE ba.status = 'started';
`

	// TODO -- use `RETURNING` to return the IDs of the updated activities?
	if _, err := tx.ExecContext(ctx, finishStartedBatches); err != nil {
		return ctxerr.Wrap(ctx, err, "marking activities as completed")
	}

	return nil
}
// MarkActivitiesAsCompleted runs markActivitiesAsCompleted against the
// datastore's writer, outside of any caller-managed transaction.
func (ds *Datastore) MarkActivitiesAsCompleted(ctx context.Context) error {
	writer := ds.writer(ctx)
	return ds.markActivitiesAsCompleted(ctx, writer)
}