mirror of
https://github.com/fleetdm/fleet
synced 2026-04-21 21:47:20 +00:00
<!-- Add the related story/sub-task/bug number, like Resolves #123, or remove if NA --> **Related issue:** Resolves #42799 When a macOS device acknowledges a lock command, it can immediately send a trailing Idle check-in. CleanAppleMDMLock now requires unlock_ref to have been set at least 5 minutes ago before clearing the lock state, which prevents that trailing Idle from prematurely clearing the MDM lock state.
3075 lines
96 KiB
Go
3075 lines
96 KiB
Go
package mysql
|
|
|
|
import (
|
|
"context"
|
|
"crypto/md5" //nolint:gosec
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
constants "github.com/fleetdm/fleet/v4/pkg/scripts"
|
|
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/fleetdm/fleet/v4/server/ptr"
|
|
"github.com/google/uuid"
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
func (ds *Datastore) NewHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) {
|
|
var res *fleet.HostScriptResult
|
|
return res, ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var err error
|
|
if request.ScriptContentID == 0 {
|
|
// then we are doing a sync execution, so create the contents first
|
|
scRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
id, _ := scRes.LastInsertId()
|
|
request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115
|
|
}
|
|
res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, false)
|
|
return err
|
|
})
|
|
}
|
|
|
|
func (ds *Datastore) newHostScriptExecutionRequest(ctx context.Context, tx sqlx.ExtContext, request *fleet.HostScriptRequestPayload, isInternal bool) (*fleet.HostScriptResult, error) {
|
|
const (
|
|
getStmt = `
|
|
SELECT
|
|
ua.id, ua.host_id, ua.execution_id, ua.created_at, sua.script_id, sua.policy_id, ua.user_id,
|
|
payload->'$.sync_request' AS sync_request,
|
|
sc.contents as script_contents, sua.setup_experience_script_id
|
|
FROM
|
|
upcoming_activities ua
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
INNER JOIN script_contents sc
|
|
ON sua.script_content_id = sc.id
|
|
WHERE
|
|
ua.id = ?
|
|
`
|
|
)
|
|
|
|
_, activityID, err := ds.insertNewHostScriptExecution(ctx, tx, request, isInternal)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "inserting new script execution request")
|
|
}
|
|
|
|
var script fleet.HostScriptResult
|
|
err = sqlx.GetContext(ctx, tx, &script, getStmt, activityID)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "getting the created host script activity to return")
|
|
}
|
|
|
|
return &script, nil
|
|
}
|
|
|
|
func (ds *Datastore) insertNewHostScriptExecution(ctx context.Context, tx sqlx.ExtContext, request *fleet.HostScriptRequestPayload, isInternal bool) (string, int64, error) {
|
|
const (
|
|
insUAStmt = `
|
|
INSERT INTO upcoming_activities
|
|
(host_id, priority, user_id, fleet_initiated, activity_type, execution_id, payload)
|
|
VALUES
|
|
(?, ?, ?, ?, 'script', ?,
|
|
JSON_OBJECT(
|
|
'sync_request', ?,
|
|
'is_internal', ?,
|
|
'user', (SELECT JSON_OBJECT('name', name, 'email', email, 'gravatar_url', gravatar_url) FROM users WHERE id = ?)
|
|
)
|
|
)`
|
|
|
|
insSUAStmt = `
|
|
INSERT INTO script_upcoming_activities
|
|
(upcoming_activity_id, script_id, script_content_id, policy_id, setup_experience_script_id)
|
|
VALUES
|
|
(?, ?, ?, ?, ?)
|
|
`
|
|
)
|
|
|
|
execID := uuid.New().String()
|
|
result, err := tx.ExecContext(ctx, insUAStmt,
|
|
request.HostID,
|
|
request.Priority(),
|
|
request.UserID,
|
|
request.PolicyID != nil, // fleet-initiated if request is via a policy failure
|
|
execID,
|
|
request.SyncRequest,
|
|
isInternal,
|
|
request.UserID,
|
|
)
|
|
if err != nil {
|
|
return "", 0, ctxerr.Wrap(ctx, err, "new script upcoming activity")
|
|
}
|
|
|
|
activityID, _ := result.LastInsertId()
|
|
_, err = tx.ExecContext(ctx, insSUAStmt,
|
|
activityID,
|
|
request.ScriptID,
|
|
request.ScriptContentID,
|
|
request.PolicyID,
|
|
request.SetupExperienceScriptID,
|
|
)
|
|
if err != nil {
|
|
return "", 0, ctxerr.Wrap(ctx, err, "new join script upcoming activity")
|
|
}
|
|
|
|
if _, err := ds.activateNextUpcomingActivity(ctx, tx, request.HostID, ""); err != nil {
|
|
return "", 0, ctxerr.Wrap(ctx, err, "activate next activity")
|
|
}
|
|
|
|
return execID, activityID, nil
|
|
}
|
|
|
|
// truncateScriptResult keeps only the last 10000 runes of a script's output.
// A cheap byte-level cut (at most utf8.UTFMax bytes per rune) runs first so
// that huge outputs are not fully decoded. Rune-level trimming is applied only
// when the output still exceeds the limit. That step decodes to runes, which
// also drops any partial rune left at the start by the byte cut.
func truncateScriptResult(output string) string {
	const maxOutputRuneLen = 10000

	if maxBytes := utf8.UTFMax * maxOutputRuneLen; len(output) > maxBytes {
		output = output[len(output)-maxBytes:]
	}
	if utf8.RuneCountInString(output) <= maxOutputRuneLen {
		return output
	}

	runes := []rune(output)
	return string(runes[len(runes)-maxOutputRuneLen:])
}
|
|
|
|
func (ds *Datastore) SetHostScriptExecutionResult(ctx context.Context, result *fleet.HostScriptResultPayload, attemptNumber *int) (*fleet.HostScriptResult,
|
|
string, error,
|
|
) {
|
|
const resultExistsStmt = `
|
|
SELECT
|
|
1
|
|
FROM
|
|
host_script_results
|
|
WHERE
|
|
host_id = ? AND
|
|
execution_id = ? AND
|
|
exit_code IS NOT NULL
|
|
`
|
|
|
|
const updStmt = `
|
|
UPDATE host_script_results SET
|
|
output = ?,
|
|
runtime = ?,
|
|
exit_code = ?,
|
|
timeout = ?,
|
|
attempt_number = ?
|
|
WHERE
|
|
host_id = ? AND
|
|
execution_id = ?`
|
|
|
|
const hostMDMActionsStmt = `
|
|
SELECT 'uninstall' AS action
|
|
FROM
|
|
host_software_installs
|
|
WHERE
|
|
execution_id = :execution_id AND host_id = :host_id
|
|
UNION -- host_mdm_actions query (and thus row in union) must be last to avoid #25144
|
|
SELECT
|
|
CASE
|
|
WHEN lock_ref = :execution_id THEN 'lock_ref'
|
|
WHEN unlock_ref = :execution_id THEN 'unlock_ref'
|
|
WHEN wipe_ref = :execution_id THEN 'wipe_ref'
|
|
ELSE ''
|
|
END AS action
|
|
FROM
|
|
host_mdm_actions
|
|
WHERE
|
|
host_id = :host_id
|
|
`
|
|
|
|
output := truncateScriptResult(result.Output)
|
|
|
|
var hsr *fleet.HostScriptResult
|
|
var action string
|
|
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var resultExists bool
|
|
err := sqlx.GetContext(ctx, tx, &resultExists, resultExistsStmt, result.HostID, result.ExecutionID)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
|
return ctxerr.Wrap(ctx, err, "check if host script result exists")
|
|
}
|
|
if resultExists {
|
|
ds.logger.DebugContext(ctx, "duplicate script execution result sent, will be ignored (original result is preserved)",
|
|
"host_id", result.HostID,
|
|
"execution_id", result.ExecutionID,
|
|
)
|
|
|
|
// still do the activate next activity to ensure progress as there was
|
|
// an unexpected flow if we get here.
|
|
if _, err := ds.activateNextUpcomingActivity(ctx, tx, result.HostID, result.ExecutionID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "activate next activity")
|
|
}
|
|
|
|
// succeed but leave hsr nil
|
|
return nil
|
|
}
|
|
|
|
res, err := tx.ExecContext(ctx, updStmt,
|
|
output,
|
|
result.Runtime,
|
|
// Windows error codes are signed 32-bit integers, but are
|
|
// returned as unsigned integers by the windows API. The
|
|
// software that receives them is responsible for casting
|
|
// it to a 32-bit signed integer.
|
|
// See /orbit/pkg/scripts/exec_windows.go
|
|
int32(result.ExitCode), //nolint:gosec // dismiss G115
|
|
result.Timeout,
|
|
attemptNumber,
|
|
result.HostID,
|
|
result.ExecutionID,
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "update host script result")
|
|
}
|
|
|
|
if n, _ := res.RowsAffected(); n > 0 {
|
|
// it did update, so return the updated result
|
|
hsr, err = ds.getHostScriptExecutionResultDB(ctx, tx, result.ExecutionID, scriptExecutionSearchOpts{IncludeCanceled: true})
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "load updated host script result")
|
|
}
|
|
|
|
// look up if that script was a lock/unlock/wipe/uninstall script for that host,
|
|
// and if so update the host_mdm_actions table accordingly.
|
|
namedArgs := map[string]any{
|
|
"host_id": result.HostID,
|
|
"execution_id": result.ExecutionID,
|
|
}
|
|
stmt, args, err := sqlx.Named(hostMDMActionsStmt, namedArgs)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "build named query for host mdm actions")
|
|
}
|
|
err = sqlx.GetContext(ctx, tx, &action, stmt, args...)
|
|
if err != nil && !errors.Is(err, sql.ErrNoRows) { // ignore ErrNoRows, refCol will be empty
|
|
return ctxerr.Wrap(ctx, err, "lookup host script corresponding mdm action")
|
|
}
|
|
|
|
switch action {
|
|
case "":
|
|
// do nothing
|
|
case "uninstall":
|
|
err = ds.updateUninstallStatusFromResult(ctx, tx, result.HostID, result.ExecutionID, result.ExitCode)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "update host uninstall action based on script result")
|
|
}
|
|
default: // lock/unlock/wipe
|
|
err = updateHostLockWipeStatusFromResult(ctx, tx, result.HostID, action, result.ExitCode == 0)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "update host mdm action based on script result")
|
|
}
|
|
}
|
|
}
|
|
|
|
if _, err := ds.activateNextUpcomingActivity(ctx, tx, result.HostID, result.ExecutionID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "activate next activity")
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, "", err
|
|
}
|
|
return hsr, action, nil
|
|
}
|
|
|
|
func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID uint, onlyShowInternal bool) ([]*fleet.HostScriptResult, error) {
|
|
return ds.listUpcomingHostScriptExecutions(ctx, hostID, onlyShowInternal, false)
|
|
}
|
|
|
|
func (ds *Datastore) ListReadyToExecuteScriptsForHost(ctx context.Context, hostID uint, onlyShowInternal bool) ([]*fleet.HostScriptResult, error) {
|
|
return ds.listUpcomingHostScriptExecutions(ctx, hostID, onlyShowInternal, true)
|
|
}
|
|
|
|
func (ds *Datastore) listUpcomingHostScriptExecutions(ctx context.Context, hostID uint, onlyShowInternal, onlyReadyToExecute bool) ([]*fleet.HostScriptResult, error) {
|
|
extraWhere := ""
|
|
if onlyShowInternal {
|
|
// software_uninstalls are implicitly internal
|
|
extraWhere = " AND COALESCE(ua.payload->'$.is_internal', 1) = 1"
|
|
}
|
|
if onlyReadyToExecute {
|
|
extraWhere += " AND ua.activated_at IS NOT NULL"
|
|
}
|
|
// this selects software uninstalls too as they run as scripts
|
|
listStmt := fmt.Sprintf(`
|
|
SELECT
|
|
id,
|
|
host_id,
|
|
execution_id,
|
|
script_id,
|
|
created_at
|
|
FROM (
|
|
SELECT
|
|
ua.id,
|
|
ua.host_id,
|
|
ua.execution_id,
|
|
sua.script_id,
|
|
ua.priority,
|
|
ua.created_at,
|
|
IF(ua.activated_at IS NULL, 0, 1) AS topmost
|
|
FROM
|
|
upcoming_activities ua
|
|
-- left join because software_uninstall has no script join
|
|
LEFT JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
WHERE
|
|
ua.host_id = ? AND
|
|
ua.activity_type IN ('script', 'software_uninstall')
|
|
%s
|
|
ORDER BY topmost DESC, priority DESC, created_at ASC) t`, extraWhere)
|
|
|
|
var results []*fleet.HostScriptResult
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, listStmt, hostID); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "list pending host script executions")
|
|
}
|
|
return results, nil
|
|
}
|
|
|
|
func (ds *Datastore) IsExecutionPendingForHost(ctx context.Context, hostID uint, scriptID uint) (bool, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
1
|
|
FROM
|
|
upcoming_activities ua
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
WHERE
|
|
ua.host_id = ? AND
|
|
ua.activity_type = 'script' AND
|
|
sua.script_id = ?
|
|
`
|
|
|
|
var results []*uint
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, getStmt, hostID, scriptID); err != nil {
|
|
return false, ctxerr.Wrap(ctx, err, "is execution pending for host")
|
|
}
|
|
return len(results) > 0, nil
|
|
}
|
|
|
|
// scriptExecutionSearchOpts adjusts how getHostScriptExecutionResultDB looks up
// a host script result.
type scriptExecutionSearchOpts struct {
	// IncludeCanceled also matches host_script_results rows marked as canceled
	// (by default only rows with canceled = 0 are returned).
	IncludeCanceled bool
	// UninstallHostID, when non-zero, restricts the lookup to executions that
	// are software uninstalls (host_software_installs.uninstall = TRUE) of
	// this host.
	UninstallHostID uint
}
|
|
|
|
func (ds *Datastore) GetHostScriptExecutionResult(ctx context.Context, execID string) (*fleet.HostScriptResult, error) {
|
|
return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID, scriptExecutionSearchOpts{})
|
|
}
|
|
|
|
func (ds *Datastore) GetSelfServiceUninstallScriptExecutionResult(ctx context.Context, execID string, hostID uint) (*fleet.HostScriptResult, error) {
|
|
return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID, scriptExecutionSearchOpts{UninstallHostID: hostID})
|
|
}
|
|
|
|
func (ds *Datastore) getHostScriptExecutionResultDB(ctx context.Context, q sqlx.QueryerContext, execID string, opts scriptExecutionSearchOpts) (*fleet.HostScriptResult, error) {
|
|
var activeParams []any
|
|
|
|
canceledCondition := ""
|
|
if !opts.IncludeCanceled {
|
|
canceledCondition = " AND hsr.canceled = 0"
|
|
}
|
|
|
|
uninstallCondition := ""
|
|
if opts.UninstallHostID > 0 {
|
|
uninstallCondition = `JOIN host_software_installs hsi ON hsi.execution_id = hsr.execution_id
|
|
AND hsi.uninstall = TRUE AND hsr.host_id = ?`
|
|
activeParams = append(activeParams, opts.UninstallHostID)
|
|
}
|
|
|
|
activeParams = append(activeParams, execID)
|
|
|
|
getActiveStmt := fmt.Sprintf(`
|
|
SELECT
|
|
hsr.id,
|
|
hsr.host_id,
|
|
hsr.execution_id,
|
|
sc.contents as script_contents,
|
|
hsr.script_id,
|
|
hsr.policy_id,
|
|
hsr.output,
|
|
hsr.runtime,
|
|
hsr.exit_code,
|
|
hsr.timeout,
|
|
hsr.created_at,
|
|
hsr.user_id,
|
|
hsr.sync_request,
|
|
hsr.host_deleted_at,
|
|
hsr.setup_experience_script_id,
|
|
hsr.canceled,
|
|
bahr.batch_execution_id,
|
|
hsr.attempt_number
|
|
FROM
|
|
host_script_results hsr
|
|
LEFT JOIN
|
|
batch_activity_host_results bahr ON hsr.execution_id = bahr.host_execution_id
|
|
JOIN
|
|
script_contents sc
|
|
%s
|
|
WHERE
|
|
hsr.execution_id = ? AND
|
|
hsr.script_content_id = sc.id
|
|
%s
|
|
`, uninstallCondition, canceledCondition)
|
|
|
|
// We don't include upcoming uninstall script executions in results (different activity type, and they're blank anyway)
|
|
const getUpcomingStmt = `
|
|
SELECT
|
|
0 as id,
|
|
ua.host_id,
|
|
ua.execution_id,
|
|
sc.contents as script_contents,
|
|
sua.script_id,
|
|
sua.policy_id,
|
|
'' as output,
|
|
0 as runtime,
|
|
NULL as exit_code,
|
|
NULL as timeout,
|
|
ua.created_at,
|
|
ua.user_id,
|
|
COALESCE(ua.payload->'$.sync_request', 0) as sync_request,
|
|
NULL as host_deleted_at,
|
|
sua.setup_experience_script_id,
|
|
0 as canceled,
|
|
NULL as batch_execution_id,
|
|
NULL as attempt_number
|
|
FROM
|
|
upcoming_activities ua
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
INNER JOIN
|
|
script_contents sc
|
|
ON sua.script_content_id = sc.id
|
|
WHERE
|
|
ua.execution_id = ? AND
|
|
ua.activity_type = 'script'
|
|
`
|
|
|
|
var result fleet.HostScriptResult
|
|
if err := sqlx.GetContext(ctx, q, &result, getActiveStmt, activeParams...); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
// try with upcoming activities
|
|
err = sqlx.GetContext(ctx, q, &result, getUpcomingStmt, execID)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ctxerr.Wrap(ctx, notFound("HostScriptResult").WithName(execID))
|
|
}
|
|
}
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "get host script result")
|
|
}
|
|
}
|
|
return &result, nil
|
|
}
|
|
|
|
func (ds *Datastore) CountHostScriptAttempts(ctx context.Context, hostID, scriptID, policyID uint) (int, error) {
|
|
var count int
|
|
// Only count attempts from the current retry sequence.
|
|
// When a policy passes, all attempt_number values are reset to 0 to mark them as "old sequence".
|
|
// We count attempts where attempt_number > 0 (current sequence) OR attempt_number IS NULL (currently being processed).
|
|
err := sqlx.GetContext(ctx, ds.reader(ctx), &count, `
|
|
SELECT COUNT(*)
|
|
FROM host_script_results
|
|
WHERE host_id = ?
|
|
AND script_id = ?
|
|
AND policy_id = ?
|
|
AND canceled = 0
|
|
AND (attempt_number > 0 OR attempt_number IS NULL)
|
|
`, hostID, scriptID, policyID)
|
|
if err != nil {
|
|
return 0, ctxerr.Wrap(ctx, err, "count host script attempts")
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func (ds *Datastore) NewScript(ctx context.Context, script *fleet.Script) (*fleet.Script, error) {
|
|
var res sql.Result
|
|
err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
var err error
|
|
|
|
// first insert script contents
|
|
scRes, err := insertScriptContents(ctx, tx, script.ScriptContents)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
id, _ := scRes.LastInsertId()
|
|
|
|
// then create the script entity
|
|
res, err = insertScript(ctx, tx, script, uint(id)) //nolint:gosec // dismiss G115
|
|
return err
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
id, _ := res.LastInsertId()
|
|
return ds.getScriptDB(ctx, ds.writer(ctx), uint(id)) //nolint:gosec // dismiss G115
|
|
}
|
|
|
|
// UpdateScriptContents replaces the contents of the saved script scriptID and
// returns the updated script.
//
// In a single withRetryTxx transaction it:
//   - stores scriptContents (insertScriptContents reuses an existing row with
//     the same checksum) and repoints the script at it, or only bumps
//     updated_at when the content id did not change;
//   - tries to delete the previous script_contents row if it is no longer
//     referenced (failures are logged and reported, not returned);
//   - cancels the script's pending executions; and
//   - resets attempt numbers of the script's policy automation executions.
func (ds *Datastore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) {
	err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		// Get the current script_content_id
		// NOTE(review): if scriptID does not exist this returns a wrapped
		// sql.ErrNoRows rather than a notFound error — confirm callers expect that.
		var oldContentID int64
		getCurrentStmt := `SELECT script_content_id FROM scripts WHERE id = ?`
		err := sqlx.GetContext(ctx, tx, &oldContentID, getCurrentStmt, scriptID)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "getting current script content id")
		}

		// Insert or get existing content (insertScriptContents handles deduplication)
		scRes, err := insertScriptContents(ctx, tx, scriptContents)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "inserting/getting script contents")
		}
		newContentID, _ := scRes.LastInsertId()

		// Update the script to point to the new content
		if newContentID != oldContentID {
			updateStmt := `
UPDATE scripts
SET script_content_id = ?
WHERE id = ?
`
			_, err = tx.ExecContext(ctx, updateStmt, newContentID, scriptID)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "updating script content reference")
			}

			// Try to clean up the old content if no longer used
			// Don't fail the transaction if cleanup fails; just log it
			// NOTE(review): cleanup runs before cancelUpcomingScriptActivities
			// below, so script_upcoming_activities rows still referencing the old
			// content keep it alive here. Also, if the cleanup error itself ended
			// the transaction (e.g. an InnoDB deadlock rolls the whole transaction
			// back), continuing may run the remaining statements outside of it —
			// confirm this best-effort handling is safe.
			if err := ds.cleanupScriptContent(ctx, tx, uint(oldContentID)); err != nil { //nolint:gosec
				ds.logger.ErrorContext(ctx, "failed to cleanup orphaned script content",
					"script_id", scriptID, "old_content_id", oldContentID, "err", err)
				ctxerr.Handle(ctx, err)
			}
		} else {
			// Just update the timestamp
			_, err = tx.ExecContext(ctx, "UPDATE scripts SET updated_at = NOW() WHERE id = ?", scriptID)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "updating script updated_at time")
			}
		}

		// Cancel pending executions
		if err := ds.cancelUpcomingScriptActivities(ctx, tx, scriptID); err != nil {
			return ctxerr.Wrap(ctx, err, "canceling upcoming script executions")
		}

		// When a script is modified reset attempt numbers for policy automations
		if err := ds.resetScriptPolicyAutomationAttempts(ctx, tx, scriptID); err != nil {
			return ctxerr.Wrap(ctx, err, "resetting policy automation attempts for script")
		}

		return nil
	})
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "updating script contents")
	}
	return ds.Script(ctx, scriptID)
}
|
|
|
|
func (ds *Datastore) cancelUpcomingScriptActivities(ctx context.Context, db sqlx.ExtContext, scriptID uint) error {
|
|
const stmt = `
|
|
SELECT
|
|
ua.execution_id,
|
|
ua.host_id
|
|
FROM
|
|
script_upcoming_activities sua
|
|
INNER JOIN
|
|
upcoming_activities ua ON ua.id = sua.upcoming_activity_id
|
|
WHERE
|
|
sua.script_id = ?
|
|
`
|
|
|
|
var upcomingExecutions []struct {
|
|
ExecutionID string `db:"execution_id"`
|
|
HostID uint `db:"host_id"`
|
|
}
|
|
|
|
if err := sqlx.SelectContext(ctx, db, &upcomingExecutions, stmt, scriptID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "selecting upcoming script executions")
|
|
}
|
|
|
|
for _, upcomingExecution := range upcomingExecutions {
|
|
if _, err := ds.cancelHostUpcomingActivity(ctx, db, upcomingExecution.HostID, upcomingExecution.ExecutionID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "canceling upcoming activity")
|
|
}
|
|
}
|
|
|
|
// Cancel scripts that were already activated and are in host_script_results but not yet executed
|
|
const activatedStmt = `UPDATE host_script_results SET canceled = 1 WHERE script_id = ? AND exit_code IS NULL AND canceled = 0`
|
|
if _, err := db.ExecContext(ctx, activatedStmt, scriptID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "canceling activated pending script executions")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// resetScriptPolicyAutomationAttempts resets all attempt numbers for script executions for policy automations
|
|
func (ds *Datastore) resetScriptPolicyAutomationAttempts(ctx context.Context, db sqlx.ExecerContext, scriptID uint) error {
|
|
_, err := db.ExecContext(ctx, `
|
|
UPDATE host_script_results
|
|
SET attempt_number = 0
|
|
WHERE script_id = ? AND policy_id IS NOT NULL AND (attempt_number > 0 OR attempt_number IS NULL)
|
|
`, scriptID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "reset policy automation script attempts")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func insertScript(ctx context.Context, tx sqlx.ExtContext, script *fleet.Script, scriptContentsID uint) (sql.Result, error) {
|
|
const insertStmt = `
|
|
INSERT INTO
|
|
scripts (
|
|
team_id, global_or_team_id, name, script_content_id
|
|
)
|
|
VALUES
|
|
(?, ?, ?, ?)
|
|
`
|
|
var globalOrTeamID uint
|
|
if script.TeamID != nil {
|
|
globalOrTeamID = *script.TeamID
|
|
}
|
|
res, err := tx.ExecContext(ctx, insertStmt,
|
|
script.TeamID, globalOrTeamID, script.Name, scriptContentsID)
|
|
if err != nil {
|
|
if IsDuplicate(err) {
|
|
// name already exists for this team/global
|
|
err = alreadyExists("Script", script.Name)
|
|
} else if isChildForeignKeyError(err) {
|
|
// team does not exist
|
|
err = foreignKey("scripts", fmt.Sprintf("team_id=%v", script.TeamID))
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "insert script")
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
func insertScriptContents(ctx context.Context, tx sqlx.ExtContext, contents string) (sql.Result, error) {
|
|
const insertStmt = `
|
|
INSERT INTO
|
|
script_contents (
|
|
md5_checksum, contents
|
|
)
|
|
VALUES (UNHEX(?),?)
|
|
ON DUPLICATE KEY UPDATE
|
|
id=LAST_INSERT_ID(id)
|
|
`
|
|
|
|
md5Checksum := md5ChecksumScriptContent(contents)
|
|
res, err := tx.ExecContext(ctx, insertStmt, md5Checksum, contents)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "insert script contents")
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func md5ChecksumScriptContent(s string) string {
|
|
return md5ChecksumBytes([]byte(s))
|
|
}
|
|
|
|
// md5ChecksumBytes returns the MD5 digest of b as 32 upper-case hex characters.
// In this file it is used as a content checksum (script_contents.md5_checksum),
// not for security.
func md5ChecksumBytes(b []byte) string {
	digest := md5.Sum(b) //nolint:gosec
	encoded := make([]byte, hex.EncodedLen(len(digest)))
	hex.Encode(encoded, digest[:])
	return strings.ToUpper(string(encoded))
}
|
|
|
|
func (ds *Datastore) cleanupScriptContent(ctx context.Context, tx sqlx.ExtContext, contentID uint) error {
|
|
// Check if this content is still being used anywhere
|
|
var usageCount int
|
|
stmt := `
|
|
SELECT COUNT(*) FROM (
|
|
SELECT 1 FROM scripts WHERE script_content_id = ?
|
|
UNION ALL
|
|
SELECT 1 FROM setup_experience_scripts WHERE script_content_id = ?
|
|
UNION ALL
|
|
SELECT 1 FROM software_installers WHERE
|
|
install_script_content_id = ?
|
|
OR uninstall_script_content_id = ?
|
|
OR post_install_script_content_id = ?
|
|
UNION ALL
|
|
SELECT 1 FROM script_upcoming_activities WHERE script_content_id = ?
|
|
UNION ALL
|
|
SELECT 1 FROM host_script_results WHERE script_content_id = ?
|
|
) t
|
|
`
|
|
err := sqlx.GetContext(ctx, tx, &usageCount, stmt,
|
|
contentID, contentID, contentID, contentID, contentID, contentID, contentID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "checking script content usage for cleanup")
|
|
}
|
|
|
|
if usageCount == 0 {
|
|
// Not being used, safe to delete
|
|
deleteStmt := `DELETE FROM script_contents WHERE id = ?`
|
|
_, err = tx.ExecContext(ctx, deleteStmt, contentID)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "deleting unused script content")
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (ds *Datastore) Script(ctx context.Context, id uint) (*fleet.Script, error) {
|
|
return ds.getScriptDB(ctx, ds.reader(ctx), id)
|
|
}
|
|
|
|
func (ds *Datastore) getScriptDB(ctx context.Context, q sqlx.QueryerContext, id uint) (*fleet.Script, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
id,
|
|
team_id,
|
|
name,
|
|
created_at,
|
|
updated_at,
|
|
script_content_id
|
|
FROM
|
|
scripts
|
|
WHERE
|
|
id = ?
|
|
`
|
|
var script fleet.Script
|
|
if err := sqlx.GetContext(ctx, q, &script, getStmt, id); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, notFound("Script").WithID(id)
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "get script")
|
|
}
|
|
return &script, nil
|
|
}
|
|
|
|
func (ds *Datastore) GetScriptContents(ctx context.Context, id uint) ([]byte, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
sc.contents
|
|
FROM
|
|
script_contents sc
|
|
JOIN scripts s ON s.script_content_id = sc.id
|
|
WHERE
|
|
s.id = ?
|
|
`
|
|
var contents []byte
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &contents, getStmt, id); err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, notFound("Script").WithID(id)
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "get script contents")
|
|
}
|
|
return contents, nil
|
|
}
|
|
|
|
func (ds *Datastore) GetAnyScriptContents(ctx context.Context, id uint) ([]byte, error) {
|
|
const getStmt = `
|
|
SELECT
|
|
sc.contents
|
|
FROM
|
|
script_contents sc
|
|
WHERE
|
|
sc.id = ?
|
|
`
|
|
var contents []byte
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &contents, getStmt, id); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, notFound("Script").WithID(id)
|
|
}
|
|
return nil, ctxerr.Wrap(ctx, err, "get any script contents")
|
|
}
|
|
return contents, nil
|
|
}
|
|
|
|
// errDeleteScriptWithAssociatedPolicy is returned (wrapped) by DeleteScript when
// the script cannot be deleted because a policy automation references it.
var errDeleteScriptWithAssociatedPolicy = &fleet.ConflictError{Message: "Couldn't delete. Policy automation uses this script. Please remove this script from associated policy automations and try again."}
|
|
|
|
// DeleteScript deletes the saved script id along with its pending executions.
//
// Pending executions are unexecuted host_script_results rows and 'script'
// upcoming activities. Those that are sync requests created more than
// constants.MaxServerWaitTime ago are left in place. If the scripts row cannot
// be deleted because of a foreign key and a policy references the script,
// errDeleteScriptWithAssociatedPolicy is returned (wrapped). After the
// transaction commits, hosts whose activated upcoming activity was deleted get
// their next upcoming activity activated.
func (ds *Datastore) DeleteScript(ctx context.Context, id uint) error {
	var activateAffectedHosts []uint

	err := ds.withTx(ctx, func(tx sqlx.ExtContext) error {
		// drop results that never got an exit code (keeping old sync requests)
		_, err := tx.ExecContext(ctx, `DELETE FROM host_script_results WHERE script_id = ?
AND exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)`,
			id, int(constants.MaxServerWaitTime.Seconds()),
		)
		if err != nil {
			return ctxerr.Wrapf(ctx, err, "cancel pending script executions")
		}

		// load hosts that will have their upcoming_activities deleted, if that
		// activity is "activated", as that means we will have to call
		// activateNextUpcomingActivity for those hosts.
		loadAffectedHostsStmt := `
SELECT
DISTINCT host_id
FROM
upcoming_activities ua
INNER JOIN script_upcoming_activities sua
ON ua.id = sua.upcoming_activity_id
WHERE sua.script_id = ? AND
ua.activity_type = 'script' AND
ua.activated_at IS NOT NULL AND
(ua.payload->'$.sync_request' = 0 OR
ua.created_at >= NOW() - INTERVAL ? SECOND)`
		var affectedHosts []uint
		if err := sqlx.SelectContext(ctx, tx, &affectedHosts, loadAffectedHostsStmt,
			id, int(constants.MaxServerWaitTime.Seconds())); err != nil {
			return ctxerr.Wrapf(ctx, err, "load affected hosts")
		}
		activateAffectedHosts = affectedHosts

		// same filter as loadAffectedHostsStmt, minus the activated_at condition
		_, err = tx.ExecContext(ctx, `DELETE FROM upcoming_activities
USING upcoming_activities
INNER JOIN script_upcoming_activities sua
ON upcoming_activities.id = sua.upcoming_activity_id
WHERE sua.script_id = ? AND
upcoming_activities.activity_type = 'script' AND
(upcoming_activities.payload->'$.sync_request' = 0 OR
upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
`,
			id, int(constants.MaxServerWaitTime.Seconds()),
		)
		if err != nil {
			return ctxerr.Wrapf(ctx, err, "cancel upcoming pending script executions")
		}

		_, err = tx.ExecContext(ctx, `DELETE FROM scripts WHERE id = ?`, id)
		if err != nil {
			if isMySQLForeignKey(err) {
				// Check if the script is referenced by a policy automation.
				var count int
				if err := sqlx.GetContext(ctx, tx, &count, `SELECT COUNT(*) FROM policies WHERE script_id = ?`, id); err != nil {
					return ctxerr.Wrapf(ctx, err, "getting reference from policies")
				}
				if count > 0 {
					return ctxerr.Wrap(ctx, errDeleteScriptWithAssociatedPolicy, "delete script")
				}
			}
			return ctxerr.Wrap(ctx, err, "delete script")
		}

		return nil
	})
	if err != nil {
		return err
	}

	// we call this outside of the transaction to avoid a
	// long-running/deadlock-prone transaction, as many hosts could be affected.
	return ds.activateNextUpcomingActivityForBatchOfHosts(ctx, activateAffectedHosts)
}
|
|
|
|
// deletePendingHostScriptExecutionsForPolicy should be called when a policy is deleted to remove any pending script executions
|
|
func (ds *Datastore) deletePendingHostScriptExecutionsForPolicy(ctx context.Context, teamID *uint, policyID uint) error {
|
|
var globalOrTeamID uint
|
|
if teamID != nil {
|
|
globalOrTeamID = *teamID
|
|
}
|
|
|
|
deletePendingFunc := func(stmt string, args ...any) error {
|
|
_, err := ds.writer(ctx).ExecContext(ctx, stmt, args...)
|
|
return ctxerr.Wrap(ctx, err, "delete pending host script executions for policy")
|
|
}
|
|
|
|
deleteHSRStmt := `
|
|
DELETE FROM
|
|
host_script_results
|
|
WHERE
|
|
policy_id = ? AND
|
|
script_id IN (
|
|
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
|
|
) AND
|
|
exit_code IS NULL
|
|
`
|
|
|
|
if err := deletePendingFunc(deleteHSRStmt, policyID, globalOrTeamID); err != nil {
|
|
return err
|
|
}
|
|
|
|
loadAffectedHostsStmt := `
|
|
SELECT
|
|
DISTINCT host_id
|
|
FROM
|
|
upcoming_activities ua
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON ua.id = sua.upcoming_activity_id
|
|
WHERE
|
|
ua.activity_type = 'script' AND
|
|
ua.activated_at IS NOT NULL AND
|
|
sua.policy_id = ? AND
|
|
sua.script_id IN (
|
|
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
|
|
)`
|
|
var affectedHosts []uint
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &affectedHosts,
|
|
loadAffectedHostsStmt, policyID, globalOrTeamID); err != nil {
|
|
return err
|
|
}
|
|
|
|
deleteUAStmt := `
|
|
DELETE FROM
|
|
upcoming_activities
|
|
USING
|
|
upcoming_activities
|
|
INNER JOIN script_upcoming_activities sua
|
|
ON upcoming_activities.id = sua.upcoming_activity_id
|
|
WHERE
|
|
upcoming_activities.activity_type = 'script' AND
|
|
sua.policy_id = ? AND
|
|
sua.script_id IN (
|
|
SELECT id FROM scripts WHERE scripts.global_or_team_id = ?
|
|
)
|
|
`
|
|
if err := deletePendingFunc(deleteUAStmt, policyID, globalOrTeamID); err != nil {
|
|
return err
|
|
}
|
|
|
|
return ds.activateNextUpcomingActivityForBatchOfHosts(ctx, affectedHosts)
|
|
}
|
|
|
|
// ListScripts returns the scripts stored for the given team, or for "no team"
// when teamID is nil, applying the paging and ordering described by opt.
//
// When opt.IncludeMetadata is set, pagination metadata is also returned and,
// if more rows than opt.PerPage came back, the extra row is dropped and
// HasNextResults is reported. Without it, the metadata return value is nil.
func (ds *Datastore) ListScripts(ctx context.Context, teamID *uint, opt fleet.ListOptions) ([]*fleet.Script, *fleet.PaginationMetadata, error) {
	const listScriptsQuery = `
SELECT
	s.id,
	s.team_id,
	s.name,
	s.created_at,
	s.updated_at
FROM
	scripts s
WHERE
	s.global_or_team_id = ?
`

	// scripts that do not belong to a team are keyed with 0
	scopeID := uint(0)
	if teamID != nil {
		scopeID = *teamID
	}

	query, queryArgs := appendListOptionsWithCursorToSQL(listScriptsQuery, []any{scopeID}, &opt)

	var found []*fleet.Script
	if err := sqlx.SelectContext(ctx, ds.reader(ctx), &found, query, queryArgs...); err != nil {
		return nil, nil, ctxerr.Wrap(ctx, err, "select scripts")
	}

	if !opt.IncludeMetadata {
		return found, nil, nil
	}

	meta := &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0}
	if len(found) > int(opt.PerPage) { //nolint:gosec // dismiss G115
		// NOTE(review): presumably appendListOptionsWithCursorToSQL fetches one
		// row past the page size when metadata is requested; that extra row only
		// signals that a next page exists and is not returned.
		meta.HasNextResults = true
		found = found[:len(found)-1]
	}
	return found, meta, nil
}
|
|
|
|
// GetScriptIDByName returns the ID of the script called name in the given team,
// or in "no team" when teamID is nil. A not-found error carrying the script
// name is returned when no such script exists; any other database error is
// wrapped with context.
func (ds *Datastore) GetScriptIDByName(ctx context.Context, name string, teamID *uint) (uint, error) {
	const selectStmt = `
SELECT
	id
FROM
	scripts
WHERE
	global_or_team_id = ?
	AND name = ?
`
	// scripts that do not belong to a team are keyed with 0
	var globalOrTeamID uint
	if teamID != nil {
		globalOrTeamID = *teamID
	}

	var id uint
	if err := sqlx.GetContext(ctx, ds.reader(ctx), &id, selectStmt, globalOrTeamID, name); err != nil {
		// errors.Is (rather than ==) keeps this check correct even if the error
		// is wrapped on its way up from the driver.
		if errors.Is(err, sql.ErrNoRows) {
			return 0, notFound("Script").WithName(name)
		}
		return 0, ctxerr.Wrap(ctx, err, "get script by name")
	}
	return id, nil
}
|
|
|
|
// GetHostScriptDetails lists the scripts of the given team (or "no team" when
// teamID is nil) that are applicable to hostPlatform. Each script comes with
// the most recent execution for hostID, if any.
//
// Applicability is decided by file extension: .ps1 for Windows, and .sh/.py for
// unix-like platforms. Other platforms get no filter.
//
// For each script, the latest pending (upcoming) execution is preferred over
// the latest completed, non-canceled one. Paging follows opt. When
// opt.IncludeMetadata is set, pagination metadata is returned and the extra
// look-ahead row is trimmed.
func (ds *Datastore) GetHostScriptDetails(ctx context.Context, hostID uint, teamID *uint, opt fleet.ListOptions, hostPlatform string) ([]*fleet.HostScriptDetail, *fleet.PaginationMetadata, error) {
	// scripts that do not belong to a team are keyed with 0
	var globalOrTeamID uint
	if teamID != nil {
		globalOrTeamID = *teamID
	}

	var extensionPatterns []string
	switch {
	case hostPlatform == "windows":
		// filter by .ps1 extension
		extensionPatterns = []string{`%.ps1`}
	case fleet.IsUnixLike(hostPlatform):
		// filter by .sh and .py extensions
		extensionPatterns = []string{`%.sh`, `%.py`}
	default:
		// no extension filter
	}

	type row struct {
		ScriptID    uint       `db:"script_id"`
		Name        string     `db:"name"`
		HSRID       *uint      `db:"hsr_id"`
		ExecutionID *string    `db:"execution_id"`
		ExecutedAt  *time.Time `db:"executed_at"`
		ExitCode    *int64     `db:"exit_code"`
	}

	// Named "query" (not "sql") so the database/sql package is not shadowed
	// inside this function.
	query := `
	WITH all_latest_activities AS (
		-- Use window function to efficiently find the latest execution per script
		-- This is O(n) (a self-join approach would be O(n²))
		SELECT * FROM (
			SELECT
				id,
				host_id,
				script_id,
				execution_id,
				created_at,
				exit_code,
				'completed' as source,
				ROW_NUMBER() OVER (
					PARTITION BY script_id
					ORDER BY created_at DESC, id DESC
				) AS row_num
			FROM
				host_script_results
			WHERE
				host_id = ? AND
				canceled = 0
		) completed_ranked
		WHERE row_num = 1

		UNION ALL

		-- latest from upcoming_activities
		SELECT * FROM (
			SELECT
				NULL as id,
				ua.host_id,
				sua.script_id,
				ua.execution_id,
				ua.created_at,
				NULL as exit_code,
				'upcoming' as source,
				ROW_NUMBER() OVER (
					PARTITION BY sua.script_id
					ORDER BY ua.created_at DESC, ua.id DESC
				) AS row_num
			FROM
				upcoming_activities ua
				INNER JOIN script_upcoming_activities sua
					ON ua.id = sua.upcoming_activity_id
			WHERE
				ua.host_id = ? AND
				ua.activity_type = 'script'
		) upcoming_ranked
		WHERE row_num = 1
	)
	SELECT
		s.id AS script_id,
		s.name,
		latest.id AS hsr_id,
		latest.created_at AS executed_at,
		latest.execution_id,
		latest.exit_code
	FROM
		scripts s
	LEFT JOIN (
		-- Pick the most recent between completed and upcoming for each script
		SELECT * FROM (
			SELECT
				*,
				ROW_NUMBER() OVER (
					PARTITION BY script_id
					ORDER BY
						CASE WHEN source = 'upcoming' THEN 1 ELSE 2 END, -- Prefer upcoming over completed
						created_at DESC,
						id DESC
				) AS final_rn
			FROM all_latest_activities
		) final_ranked
		WHERE final_rn = 1
	) latest
		ON s.id = latest.script_id
	WHERE
		(latest.host_id IS NULL OR latest.host_id = ?)
		AND s.global_or_team_id = ?
	`

	// placeholder order: completed CTE host, upcoming CTE host, outer host, team
	args := []any{hostID, hostID, hostID, globalOrTeamID}
	if len(extensionPatterns) > 0 {
		likeClauses := make([]string, 0, len(extensionPatterns))
		for _, ext := range extensionPatterns {
			likeClauses = append(likeClauses, "s.name LIKE ?")
			args = append(args, ext)
		}
		query += "\n\t\tAND (" + strings.Join(likeClauses, " OR ") + ")\n"
	}
	stmt, args := appendListOptionsWithCursorToSQL(query, args, &opt)

	var rows []*row
	if err := sqlx.SelectContext(ctx, ds.reader(ctx), &rows, stmt, args...); err != nil {
		return nil, nil, ctxerr.Wrap(ctx, err, "get host script details")
	}

	var metaData *fleet.PaginationMetadata
	if opt.IncludeMetadata {
		metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0}
		if len(rows) > int(opt.PerPage) { //nolint:gosec // dismiss G115
			metaData.HasNextResults = true
			rows = rows[:len(rows)-1]
		}
	}

	results := make([]*fleet.HostScriptDetail, 0, len(rows))
	for _, r := range rows {
		results = append(results, fleet.NewHostScriptDetail(hostID, r.ScriptID, r.Name, r.ExecutionID, r.ExecutedAt, r.ExitCode, r.HSRID))
	}

	return results, metaData, nil
}
|
|
|
|
// BatchSetScripts makes the given scripts the complete set of scripts of a team
// (or of "no team" when tmID is nil), matching existing scripts by name:
//
//   - Existing scripts whose name is not in the incoming set are deleted.
//     Policies referencing them have their script_id unset, and their pending
//     executions are cleared.
//   - Incoming scripts are upserted (INSERT ... ON DUPLICATE KEY UPDATE of the
//     script contents). Pending executions of a script that reference
//     different contents than the upserted ones are cleared.
//
// Only pending executions that are asynchronous (sync_request = 0), or that
// were created within the last constants.MaxServerWaitTime, are cleared.
//
// Once the transaction succeeds, the next upcoming activity is activated for
// every host that had an already-activated script execution removed.
//
// It returns the id, team_id and name of all scripts stored for the team after
// the update.
func (ds *Datastore) BatchSetScripts(ctx context.Context, tmID *uint, scripts []*fleet.Script) ([]fleet.ScriptResponse, error) {
	const loadExistingScripts = `
SELECT
	name
FROM
	scripts
WHERE
	global_or_team_id = ? AND
	name IN (?)
`
	const deleteAllScriptsInTeam = `
DELETE FROM
	scripts
WHERE
	global_or_team_id = ?
`
	const unsetAllScriptsFromPolicies = `UPDATE policies SET script_id = NULL WHERE team_id = ?`

	const clearAllPendingExecutionsHSR = `DELETE FROM host_script_results WHERE
	exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
	AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)`

	const loadAffectedHostsAllPendingExecutionsUA = `
SELECT
	DISTINCT host_id
FROM
	upcoming_activities ua
	INNER JOIN script_upcoming_activities sua
		ON ua.id = sua.upcoming_activity_id
WHERE
	ua.activity_type = 'script'
	AND ua.activated_at IS NOT NULL
	AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)
	AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)`

	const clearAllPendingExecutionsUA = `DELETE FROM upcoming_activities
	USING
		upcoming_activities
		INNER JOIN script_upcoming_activities sua
			ON upcoming_activities.id = sua.upcoming_activity_id
	WHERE
		upcoming_activities.activity_type = 'script'
		AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
		AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)`

	const unsetScriptsNotInListFromPolicies = `
UPDATE policies SET script_id = NULL
WHERE script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))
`

	const deleteScriptsNotInList = `
DELETE FROM
	scripts
WHERE
	global_or_team_id = ? AND
	name NOT IN (?)
`

	const clearPendingExecutionsNotInListHSR = `DELETE FROM host_script_results WHERE
	exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
	AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))`

	const loadAffectedHostsPendingExecutionsNotInListUA = `
SELECT
	DISTINCT host_id
FROM
	upcoming_activities ua
	INNER JOIN script_upcoming_activities sua
		ON ua.id = sua.upcoming_activity_id
WHERE
	ua.activity_type = 'script'
	AND ua.activated_at IS NOT NULL
	AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)
	AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))`

	const clearPendingExecutionsNotInListUA = `DELETE FROM upcoming_activities
	USING
		upcoming_activities
		INNER JOIN script_upcoming_activities sua
			ON upcoming_activities.id = sua.upcoming_activity_id
	WHERE
		upcoming_activities.activity_type = 'script'
		AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
		AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))`

	const insertNewOrEditedScript = `
INSERT INTO
	scripts (
		team_id, global_or_team_id, name, script_content_id
	)
VALUES
	(?, ?, ?, ?)
ON DUPLICATE KEY UPDATE
	script_content_id = VALUES(script_content_id), id=LAST_INSERT_ID(id)
`

	const clearPendingExecutionsWithObsoleteScriptHSR = `DELETE FROM host_script_results WHERE
	exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)
	AND script_id = ? AND script_content_id != ?`

	const loadAffectedHostsPendingExecutionsWithObsoleteScriptUA = `
SELECT
	DISTINCT host_id
FROM
	upcoming_activities ua
	INNER JOIN script_upcoming_activities sua
		ON ua.id = sua.upcoming_activity_id
WHERE
	ua.activity_type = 'script'
	AND ua.activated_at IS NOT NULL
	AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)
	AND sua.script_id = ? AND sua.script_content_id != ?`

	const clearPendingExecutionsWithObsoleteScriptUA = `DELETE FROM upcoming_activities
	USING
		upcoming_activities
		INNER JOIN script_upcoming_activities sua
			ON upcoming_activities.id = sua.upcoming_activity_id
	WHERE
		upcoming_activities.activity_type = 'script'
		AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND)
		AND sua.script_id = ? AND sua.script_content_id != ?`

	const loadInsertedScripts = `SELECT id, team_id, name FROM scripts WHERE global_or_team_id = ?`

	// use a team id of 0 if no-team
	var globalOrTeamID uint
	if tmID != nil {
		globalOrTeamID = *tmID
	}

	// Value bound to every "INTERVAL ? SECOND" placeholder above. It is computed
	// once instead of at every call site, including once per script in the loop.
	waitWindowSecs := int(constants.MaxServerWaitTime.Seconds())

	// build a list of names for the incoming scripts, will keep the
	// existing ones if there's a match and no change
	incomingNames := make([]string, len(scripts))
	// at the same time, index the incoming scripts keyed by name for ease
	// of processing
	incomingScripts := make(map[string]*fleet.Script, len(scripts))
	for i, p := range scripts {
		incomingNames[i] = p.Name
		incomingScripts[p.Name] = p
	}

	var insertedScripts []fleet.ScriptResponse
	var activateAffectedHosts []uint

	if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		// withRetryTxx may run this closure more than once. Reset the
		// outer-scope accumulators so a failed attempt can never leak rows into
		// the values returned after a successful one.
		insertedScripts = nil
		activateAffectedHosts = nil

		var existingScripts []*fleet.Script

		if len(incomingNames) > 0 {
			// load existing scripts that match the incoming scripts by names
			stmt, args, err := sqlx.In(loadExistingScripts, globalOrTeamID, incomingNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build query to load existing scripts")
			}
			if err := sqlx.SelectContext(ctx, tx, &existingScripts, stmt, args...); err != nil {
				return ctxerr.Wrap(ctx, err, "load existing scripts")
			}
		}

		// figure out if we need to delete any scripts
		keepNames := make([]string, 0, len(incomingNames))
		for _, p := range existingScripts {
			if newS := incomingScripts[p.Name]; newS != nil {
				keepNames = append(keepNames, p.Name)
			}
		}

		var (
			scriptsStmt      string
			scriptsArgs      []any
			policiesStmt     string
			policiesArgs     []any
			executionsStmt   string
			executionsArgs   []any
			extraExecStmt    string
			extraExecArgs    []any
			loadAffectedStmt string
			loadAffectedArgs []any
			err              error
			affectedHostIDs  []uint
		)
		if len(keepNames) > 0 {
			// delete the obsolete scripts
			scriptsStmt, scriptsArgs, err = sqlx.In(deleteScriptsNotInList, globalOrTeamID, keepNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build statement to delete obsolete scripts")
			}

			policiesStmt, policiesArgs, err = sqlx.In(unsetScriptsNotInListFromPolicies, globalOrTeamID, keepNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build statement to unset obsolete scripts from policies")
			}

			executionsStmt, executionsArgs, err = sqlx.In(clearPendingExecutionsNotInListHSR, waitWindowSecs, globalOrTeamID, keepNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build statement to clear pending script executions from obsolete scripts")
			}

			loadAffectedStmt, loadAffectedArgs, err = sqlx.In(loadAffectedHostsPendingExecutionsNotInListUA, waitWindowSecs, globalOrTeamID, keepNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build query to load affected hosts for upcoming script executions")
			}
			if err := sqlx.SelectContext(ctx, tx, &affectedHostIDs, loadAffectedStmt, loadAffectedArgs...); err != nil {
				return ctxerr.Wrap(ctx, err, "load affected hosts for upcoming script executions")
			}

			extraExecStmt, extraExecArgs, err = sqlx.In(clearPendingExecutionsNotInListUA, waitWindowSecs, globalOrTeamID, keepNames)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "build statement to clear upcoming pending script executions from obsolete scripts")
			}
		} else {
			scriptsStmt = deleteAllScriptsInTeam
			scriptsArgs = []any{globalOrTeamID}

			policiesStmt = unsetAllScriptsFromPolicies
			policiesArgs = []any{globalOrTeamID}

			executionsStmt = clearAllPendingExecutionsHSR
			executionsArgs = []any{waitWindowSecs, globalOrTeamID}

			if err := sqlx.SelectContext(ctx, tx, &affectedHostIDs,
				loadAffectedHostsAllPendingExecutionsUA, waitWindowSecs, globalOrTeamID); err != nil {
				return ctxerr.Wrap(ctx, err, "load affected hosts for upcoming script executions")
			}

			extraExecStmt = clearAllPendingExecutionsUA
			extraExecArgs = []any{waitWindowSecs, globalOrTeamID}
		}

		// Order matters: policy references and pending executions are cleared
		// while the obsolete script rows still exist, because those statements
		// select script ids via a subquery on scripts. The scripts themselves
		// are deleted last.
		if _, err := tx.ExecContext(ctx, policiesStmt, policiesArgs...); err != nil {
			return ctxerr.Wrap(ctx, err, "unset obsolete scripts from policies")
		}
		if _, err := tx.ExecContext(ctx, executionsStmt, executionsArgs...); err != nil {
			return ctxerr.Wrap(ctx, err, "clear obsolete script pending executions")
		}
		if _, err := tx.ExecContext(ctx, extraExecStmt, extraExecArgs...); err != nil {
			return ctxerr.Wrap(ctx, err, "clear obsolete upcoming script pending executions")
		}
		if _, err := tx.ExecContext(ctx, scriptsStmt, scriptsArgs...); err != nil {
			return ctxerr.Wrap(ctx, err, "delete obsolete scripts")
		}
		activateAffectedHosts = affectedHostIDs

		// insert the new scripts and the ones that have changed
		for _, s := range incomingScripts {
			scRes, err := insertScriptContents(ctx, tx, s.ScriptContents)
			if err != nil {
				return ctxerr.Wrapf(ctx, err, "inserting script contents for script with name %q", s.Name)
			}
			contentID, _ := scRes.LastInsertId()
			insertRes, err := tx.ExecContext(ctx, insertNewOrEditedScript, tmID, globalOrTeamID, s.Name, uint(contentID)) //nolint:gosec // dismiss G115
			if err != nil {
				return ctxerr.Wrapf(ctx, err, "insert new/edited script with name %q", s.Name)
			}
			// id=LAST_INSERT_ID(id) in the upsert makes this the script id for
			// both newly inserted and updated rows.
			scriptID, _ := insertRes.LastInsertId()

			if _, err := tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScriptHSR, waitWindowSecs, scriptID, contentID); err != nil {
				return ctxerr.Wrapf(ctx, err, "clear obsolete pending script executions with name %q", s.Name)
			}

			var affectedHosts []uint
			if err := sqlx.SelectContext(ctx, tx, &affectedHosts, loadAffectedHostsPendingExecutionsWithObsoleteScriptUA,
				waitWindowSecs, scriptID, contentID); err != nil {
				return ctxerr.Wrapf(ctx, err, "load affected hosts for upcoming script executions with name %q", s.Name)
			}
			activateAffectedHosts = append(activateAffectedHosts, affectedHosts...)

			if _, err = tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScriptUA, waitWindowSecs, scriptID, contentID); err != nil {
				return ctxerr.Wrapf(ctx, err, "clear obsolete upcoming pending script executions with name %q", s.Name)
			}
		}

		if err := sqlx.SelectContext(ctx, tx, &insertedScripts, loadInsertedScripts, globalOrTeamID); err != nil {
			return ctxerr.Wrap(ctx, err, "load inserted scripts")
		}

		return nil
	}); err != nil {
		return nil, err
	}

	// Done outside the transaction, only once the changes have been committed.
	if err := ds.activateNextUpcomingActivityForBatchOfHosts(ctx, activateAffectedHosts); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "activate next upcoming activity for batch of hosts")
	}

	return insertedScripts, nil
}
|
|
|
|
// hostMDMActions holds the lock/wipe/unlock columns of a host_mdm_actions row.
// The meaning of each *_ref value depends on the host platform; see how
// GetHostLockWipeStatus resolves them.
type hostMDMActions struct {
	// LockRef is an MDM command UUID on darwin/ios/ipados, and a script
	// execution ID on windows/linux.
	LockRef *string `db:"lock_ref"`
	// WipeRef is an MDM command UUID on darwin/ios/ipados and windows, and a
	// script execution ID on linux.
	WipeRef *string `db:"wipe_ref"`
	// UnlockRef is a timestamp (time.DateTime layout) on darwin, an MDM command
	// UUID on ios/ipados, and a script execution ID on windows/linux.
	UnlockRef *string `db:"unlock_ref"`
	// UnlockPIN is only surfaced in the lock/wipe status for darwin hosts.
	UnlockPIN *string `db:"unlock_pin"`
	// FleetPlatform, when non-empty, takes precedence over the host's current
	// platform, which a wipe may have overwritten.
	FleetPlatform string `db:"fleet_platform"`
}
|
|
|
|
// GetHostLockWipeStatus returns the lock, unlock and wipe state of host, built
// from its host_mdm_actions row and the MDM commands or script executions that
// row references.
//
// When the host has no host_mdm_actions row, a zero-value status (with only
// the platform set) is returned without error.
//
// How each reference is resolved depends on the platform:
//   - darwin/ios/ipados: lock and wipe are Apple MDM commands. Unlock is a
//     request timestamp on darwin and an MDM command on ios/ipados.
//   - windows/linux: lock and unlock are script executions. Wipe is an MDM
//     command on windows and a script execution on linux.
//
// References that no longer resolve are logged as orphans and left nil rather
// than failing the call.
func (ds *Datastore) GetHostLockWipeStatus(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) {
	const stmt = `
SELECT
	lock_ref,
	wipe_ref,
	unlock_ref,
	unlock_pin,
	fleet_platform
FROM
	host_mdm_actions
WHERE
	host_id = ?
`
	var mdmActions hostMDMActions
	hostPlatform := host.FleetPlatform()
	status := &fleet.HostLockWipeStatus{
		HostFleetPlatform: hostPlatform,
	}

	if err := sqlx.GetContext(ctx, ds.reader(ctx), &mdmActions, stmt, host.ID); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// do not return a Not Found error, return the zero-value status, which
			// will report the correct states.
			return status, nil
		}
		return nil, ctxerr.Wrap(ctx, err, "get host lock/wipe status")
	}

	// if we have a fleet platform stored in host_mdm_actions, use it instead of
	// the host.FleetPlatform() because the platform can be overwritten with an
	// unknown OS name when a Wipe gets executed.
	if mdmActions.FleetPlatform != "" {
		hostPlatform = mdmActions.FleetPlatform
		status.HostFleetPlatform = hostPlatform
	}

	switch hostPlatform {
	case "darwin", "ios", "ipados":
		if mdmActions.UnlockPIN != nil && hostPlatform == "darwin" {
			// Unlock PIN is only available for macOS hosts
			status.UnlockPIN = *mdmActions.UnlockPIN
		}

		if mdmActions.UnlockRef != nil {
			if hostPlatform == "darwin" {
				// the unlock reference is a timestamp
				// (we only store the timestamp for macOS unlocks)
				unlockRequestedAt, err := time.Parse(time.DateTime, *mdmActions.UnlockRef)
				if err != nil {
					// if the format is unexpected but there's something in UnlockRef, just
					// replace it with the current timestamp, it should still indicate that
					// an unlock was requested (e.g. in case someone plays with the data
					// directly in the DB and messes up the format).
					unlockRequestedAt = time.Now().UTC()
				}
				status.UnlockRequestedAt = unlockRequestedAt
			} else {
				// ios/ipados: the unlock reference is an MDM command uuid
				cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.UnlockRef, host.UUID)
				if err != nil && !fleet.IsNotFound(err) {
					return nil, ctxerr.Wrap(ctx, err, "get unlock reference")
				}
				if fleet.IsNotFound(err) {
					ds.logger.ErrorContext(ctx, "orphan unlock MDM command reference", "host_id", host.ID, "command_uuid", *mdmActions.UnlockRef)
				}
				status.UnlockMDMCommand = cmd
				status.UnlockMDMCommandResult = cmdRes
			}
		}

		if mdmActions.LockRef != nil {
			// the lock reference is an MDM command
			cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.LockRef, host.UUID)
			if err != nil && !fleet.IsNotFound(err) {
				return nil, ctxerr.Wrap(ctx, err, "get lock reference")
			}
			if fleet.IsNotFound(err) {
				ds.logger.ErrorContext(ctx, "orphan lock MDM command reference", "host_id", host.ID, "command_uuid", *mdmActions.LockRef)
			}

			status.LockMDMCommand = cmd
			status.LockMDMCommandResult = cmdRes

			// for ADE enrolled iDevices, we don't advance to "locked" until we have location data
			if status.LockMDMCommand != nil && (hostPlatform == "ios" || hostPlatform == "ipados") {
				_, err = ds.GetHostLocationData(ctx, host.ID)
				switch {
				case fleet.IsNotFound(err):
					status.LocationPending = true
				case err != nil:
					// wrapped for consistency with every other error path here
					return nil, ctxerr.Wrap(ctx, err, "get host location data")
				}
			}
		}

		if mdmActions.WipeRef != nil {
			// the wipe reference is an MDM command
			cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.WipeRef, host.UUID)
			if err != nil && !fleet.IsNotFound(err) {
				return nil, ctxerr.Wrap(ctx, err, "get wipe reference")
			}
			if fleet.IsNotFound(err) {
				ds.logger.ErrorContext(ctx, "orphan wipe MDM command reference", "host_id", host.ID, "command_uuid", *mdmActions.WipeRef)
			}
			status.WipeMDMCommand = cmd
			status.WipeMDMCommandResult = cmdRes
		}

	case "windows", "linux":
		// lock and unlock references are scripts
		if mdmActions.LockRef != nil {
			hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.LockRef, scriptExecutionSearchOpts{IncludeCanceled: true})
			if err != nil && !fleet.IsNotFound(err) {
				return nil, ctxerr.Wrap(ctx, err, "get lock reference script result")
			}
			if fleet.IsNotFound(err) {
				ds.logger.ErrorContext(ctx, "orphan lock script execution reference", "host_id", host.ID, "execution_id", *mdmActions.LockRef)
			}
			status.LockScript = hsr
		}

		if mdmActions.UnlockRef != nil {
			hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.UnlockRef, scriptExecutionSearchOpts{IncludeCanceled: true})
			if err != nil && !fleet.IsNotFound(err) {
				return nil, ctxerr.Wrap(ctx, err, "get unlock reference script result")
			}
			if fleet.IsNotFound(err) {
				ds.logger.ErrorContext(ctx, "orphan unlock script execution reference", "host_id", host.ID, "execution_id", *mdmActions.UnlockRef)
			}
			status.UnlockScript = hsr
		}

		// wipe is an MDM command on Windows, a script on Linux
		if mdmActions.WipeRef != nil {
			if hostPlatform == "windows" {
				cmd, cmdRes, err := ds.getHostMDMWindowsCommand(ctx, *mdmActions.WipeRef, host.UUID)
				if err != nil && !fleet.IsNotFound(err) {
					return nil, ctxerr.Wrap(ctx, err, "get wipe reference")
				}
				if fleet.IsNotFound(err) {
					ds.logger.ErrorContext(ctx, "orphan wipe MDM command reference", "host_id", host.ID, "command_uuid", *mdmActions.WipeRef)
				}
				status.WipeMDMCommand = cmd
				status.WipeMDMCommandResult = cmdRes
			} else {
				hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.WipeRef, scriptExecutionSearchOpts{IncludeCanceled: true})
				if err != nil && !fleet.IsNotFound(err) {
					return nil, ctxerr.Wrap(ctx, err, "get wipe reference script result")
				}
				if fleet.IsNotFound(err) {
					ds.logger.ErrorContext(ctx, "orphan wipe script execution reference", "host_id", host.ID, "execution_id", *mdmActions.WipeRef)
				}
				status.WipeScript = hsr
			}
		}
	}
	return status, nil
}
|
|
|
|
// GetHostsLockWipeStatusBatch gets the lock/unlock and wipe status for multiple hosts efficiently.
|
|
func (ds *Datastore) GetHostsLockWipeStatusBatch(ctx context.Context, hosts []*fleet.Host) (map[uint]*fleet.HostLockWipeStatus, error) {
|
|
if len(hosts) == 0 {
|
|
return make(map[uint]*fleet.HostLockWipeStatus), nil
|
|
}
|
|
|
|
// Build list of host IDs for queries
|
|
hostIDs := make([]uint, 0, len(hosts))
|
|
for _, host := range hosts {
|
|
hostIDs = append(hostIDs, host.ID)
|
|
}
|
|
|
|
// Query all host_mdm_actions for these hosts
|
|
stmt := `
|
|
SELECT
|
|
host_id,
|
|
lock_ref,
|
|
wipe_ref,
|
|
unlock_ref,
|
|
unlock_pin,
|
|
fleet_platform
|
|
FROM
|
|
host_mdm_actions
|
|
WHERE
|
|
host_id IN (?)
|
|
`
|
|
query, args, err := sqlx.In(stmt, hostIDs)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "build IN query for host_mdm_actions")
|
|
}
|
|
|
|
var mdmActionsRows []struct {
|
|
HostID uint `db:"host_id"`
|
|
hostMDMActions
|
|
}
|
|
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &mdmActionsRows, query, args...); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "select host_mdm_actions batch")
|
|
}
|
|
|
|
// Collect all command/script UUIDs that need to be queried, organized by type and platform
|
|
type refKey struct {
|
|
uuid string
|
|
hostUUID string
|
|
hostID uint
|
|
refType string // "lock", "unlock", "wipe"
|
|
}
|
|
|
|
appleCommandRefs := make([]refKey, 0)
|
|
windowsCommandRefs := make([]refKey, 0)
|
|
scriptRefs := make([]refKey, 0)
|
|
|
|
// Build initial status map with platform info
|
|
statusMap := make(map[uint]*fleet.HostLockWipeStatus, len(hosts))
|
|
mdmActionsMap := make(map[uint]*hostMDMActions)
|
|
|
|
for _, row := range mdmActionsRows {
|
|
mdmActionsMap[row.HostID] = &hostMDMActions{
|
|
LockRef: row.LockRef,
|
|
WipeRef: row.WipeRef,
|
|
UnlockRef: row.UnlockRef,
|
|
UnlockPIN: row.UnlockPIN,
|
|
FleetPlatform: row.FleetPlatform,
|
|
}
|
|
}
|
|
|
|
// Initialize status for all hosts and collect refs to query
|
|
for _, host := range hosts {
|
|
fleetPlatform := host.FleetPlatform()
|
|
status := &fleet.HostLockWipeStatus{
|
|
HostFleetPlatform: fleetPlatform,
|
|
}
|
|
|
|
mdmActions, hasMDMActions := mdmActionsMap[host.ID]
|
|
statusMap[host.ID] = status
|
|
if !hasMDMActions {
|
|
continue
|
|
}
|
|
|
|
// Use stored platform if available
|
|
if mdmActions.FleetPlatform != "" {
|
|
fleetPlatform = mdmActions.FleetPlatform
|
|
status.HostFleetPlatform = fleetPlatform
|
|
}
|
|
|
|
// Handle macOS unlock PIN (darwin only)
|
|
if mdmActions.UnlockPIN != nil && fleetPlatform == "darwin" {
|
|
status.UnlockPIN = *mdmActions.UnlockPIN
|
|
}
|
|
|
|
// Collect command/script references based on platform
|
|
switch fleetPlatform {
|
|
case "darwin", "ios", "ipados":
|
|
// Apple platforms use MDM commands for lock, unlock (ios/ipados only), and wipe
|
|
if mdmActions.LockRef != nil {
|
|
appleCommandRefs = append(appleCommandRefs, refKey{
|
|
uuid: *mdmActions.LockRef,
|
|
hostUUID: host.UUID,
|
|
hostID: host.ID,
|
|
refType: "lock",
|
|
})
|
|
}
|
|
if mdmActions.UnlockRef != nil && fleetPlatform != "darwin" {
|
|
// iOS/iPadOS use MDM command for unlock, darwin uses timestamp
|
|
appleCommandRefs = append(appleCommandRefs, refKey{
|
|
uuid: *mdmActions.UnlockRef,
|
|
hostUUID: host.UUID,
|
|
hostID: host.ID,
|
|
refType: "unlock",
|
|
})
|
|
} else if mdmActions.UnlockRef != nil && fleetPlatform == "darwin" {
|
|
// For macOS, unlock_ref is a timestamp, parse it here
|
|
unlockTime, err := time.Parse(time.DateTime, *mdmActions.UnlockRef)
|
|
if err != nil {
|
|
// Use current time if format is unexpected
|
|
unlockTime = time.Now().UTC()
|
|
}
|
|
status.UnlockRequestedAt = unlockTime
|
|
}
|
|
if mdmActions.WipeRef != nil {
|
|
appleCommandRefs = append(appleCommandRefs, refKey{
|
|
uuid: *mdmActions.WipeRef,
|
|
hostUUID: host.UUID,
|
|
hostID: host.ID,
|
|
refType: "wipe",
|
|
})
|
|
}
|
|
|
|
case "windows":
|
|
// Windows uses scripts for lock/unlock, MDM command for wipe
|
|
if mdmActions.LockRef != nil {
|
|
scriptRefs = append(scriptRefs, refKey{
|
|
uuid: *mdmActions.LockRef,
|
|
hostID: host.ID,
|
|
refType: "lock",
|
|
})
|
|
}
|
|
if mdmActions.UnlockRef != nil {
|
|
scriptRefs = append(scriptRefs, refKey{
|
|
uuid: *mdmActions.UnlockRef,
|
|
hostID: host.ID,
|
|
refType: "unlock",
|
|
})
|
|
}
|
|
if mdmActions.WipeRef != nil {
|
|
windowsCommandRefs = append(windowsCommandRefs, refKey{
|
|
uuid: *mdmActions.WipeRef,
|
|
hostUUID: host.UUID,
|
|
hostID: host.ID,
|
|
refType: "wipe",
|
|
})
|
|
}
|
|
|
|
case "linux":
|
|
// Linux uses scripts for lock, unlock, and wipe
|
|
if mdmActions.LockRef != nil {
|
|
scriptRefs = append(scriptRefs, refKey{
|
|
uuid: *mdmActions.LockRef,
|
|
hostID: host.ID,
|
|
refType: "lock",
|
|
})
|
|
}
|
|
if mdmActions.UnlockRef != nil {
|
|
scriptRefs = append(scriptRefs, refKey{
|
|
uuid: *mdmActions.UnlockRef,
|
|
hostID: host.ID,
|
|
refType: "unlock",
|
|
})
|
|
}
|
|
if mdmActions.WipeRef != nil {
|
|
scriptRefs = append(scriptRefs, refKey{
|
|
uuid: *mdmActions.WipeRef,
|
|
hostID: host.ID,
|
|
refType: "wipe",
|
|
})
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
// Batch query Apple MDM commands
|
|
if len(appleCommandRefs) > 0 {
|
|
cmdUUIDs := make([]string, 0, len(appleCommandRefs))
|
|
cmdUUIDMap := make(map[string][]refKey)
|
|
for _, ref := range appleCommandRefs {
|
|
if _, exists := cmdUUIDMap[ref.uuid]; !exists {
|
|
cmdUUIDs = append(cmdUUIDs, ref.uuid)
|
|
}
|
|
cmdUUIDMap[ref.uuid] = append(cmdUUIDMap[ref.uuid], ref)
|
|
}
|
|
|
|
// Query commands
|
|
cmdStmt := `SELECT command_uuid, request_type FROM nano_commands WHERE command_uuid IN (?)`
|
|
cmdQuery, cmdArgs, err := sqlx.In(cmdStmt, cmdUUIDs)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "build IN query for apple commands")
|
|
}
|
|
|
|
var commands []struct {
|
|
CommandUUID string `db:"command_uuid"`
|
|
RequestType string `db:"request_type"`
|
|
}
|
|
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &commands, cmdQuery, cmdArgs...); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "select apple mdm commands batch")
|
|
}
|
|
|
|
commandMap := make(map[string]*fleet.MDMCommand)
|
|
for _, cmd := range commands {
|
|
commandMap[cmd.CommandUUID] = &fleet.MDMCommand{
|
|
CommandUUID: cmd.CommandUUID,
|
|
RequestType: cmd.RequestType,
|
|
}
|
|
}
|
|
|
|
// Query command results
|
|
resultStmt := `SELECT command_uuid, id, status FROM nano_command_results WHERE command_uuid IN (?)`
|
|
resultQuery, resultArgs, err := sqlx.In(resultStmt, cmdUUIDs)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "build IN query for apple command results")
|
|
}
|
|
|
|
var results []struct {
|
|
CommandUUID string `db:"command_uuid"`
|
|
ID string `db:"id"`
|
|
Status string `db:"status"`
|
|
}
|
|
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, resultQuery, resultArgs...); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "select apple mdm command results batch")
|
|
}
|
|
|
|
// Build map of command_uuid -> host_uuid -> result
|
|
resultMap := make(map[string]map[string]*fleet.MDMCommandResult)
|
|
for _, res := range results {
|
|
if resultMap[res.CommandUUID] == nil {
|
|
resultMap[res.CommandUUID] = make(map[string]*fleet.MDMCommandResult)
|
|
}
|
|
// Only keep terminal statuses
|
|
if res.Status == fleet.MDMAppleStatusAcknowledged || res.Status == fleet.MDMAppleStatusError || res.Status == fleet.MDMAppleStatusCommandFormatError {
|
|
resultMap[res.CommandUUID][res.ID] = &fleet.MDMCommandResult{
|
|
CommandUUID: res.CommandUUID,
|
|
Status: res.Status,
|
|
HostUUID: res.ID,
|
|
}
|
|
}
|
|
}
|
|
|
|
// Assign commands and results to status objects
|
|
for cmdUUID, refs := range cmdUUIDMap {
|
|
cmd := commandMap[cmdUUID]
|
|
for _, ref := range refs {
|
|
status := statusMap[ref.hostID]
|
|
var cmdRes *fleet.MDMCommandResult
|
|
if resultMap[cmdUUID] != nil {
|
|
cmdRes = resultMap[cmdUUID][ref.hostUUID]
|
|
}
|
|
|
|
switch ref.refType {
|
|
case "lock":
|
|
status.LockMDMCommand = cmd
|
|
status.LockMDMCommandResult = cmdRes
|
|
case "unlock":
|
|
status.UnlockMDMCommand = cmd
|
|
status.UnlockMDMCommandResult = cmdRes
|
|
case "wipe":
|
|
status.WipeMDMCommand = cmd
|
|
status.WipeMDMCommandResult = cmdRes
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Batch query Windows MDM commands
|
|
if len(windowsCommandRefs) > 0 {
|
|
cmdUUIDs := make([]string, 0, len(windowsCommandRefs))
|
|
cmdUUIDMap := make(map[string][]refKey)
|
|
for _, ref := range windowsCommandRefs {
|
|
if _, exists := cmdUUIDMap[ref.uuid]; !exists {
|
|
cmdUUIDs = append(cmdUUIDs, ref.uuid)
|
|
}
|
|
cmdUUIDMap[ref.uuid] = append(cmdUUIDMap[ref.uuid], ref)
|
|
}
|
|
|
|
// Query commands
|
|
cmdStmt := `SELECT command_uuid, target_loc_uri FROM windows_mdm_commands WHERE command_uuid IN (?)`
|
|
cmdQuery, cmdArgs, err := sqlx.In(cmdStmt, cmdUUIDs)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "build IN query for windows commands")
|
|
}
|
|
|
|
var commands []struct {
|
|
CommandUUID string `db:"command_uuid"`
|
|
TargetLocURI string `db:"target_loc_uri"`
|
|
}
|
|
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &commands, cmdQuery, cmdArgs...); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "select windows mdm commands batch")
|
|
}
|
|
|
|
commandMap := make(map[string]*fleet.MDMCommand)
|
|
for _, cmd := range commands {
|
|
commandMap[cmd.CommandUUID] = &fleet.MDMCommand{
|
|
CommandUUID: cmd.CommandUUID,
|
|
RequestType: cmd.TargetLocURI,
|
|
}
|
|
}
|
|
|
|
// Query command results - JOIN with enrollments to get host_uuid
|
|
resultStmt := `
|
|
SELECT
|
|
wcr.command_uuid,
|
|
we.host_uuid,
|
|
wcr.status_code
|
|
FROM
|
|
windows_mdm_command_results wcr
|
|
INNER JOIN mdm_windows_enrollments we ON wcr.enrollment_id = we.id
|
|
WHERE
|
|
wcr.command_uuid IN (?)`
|
|
resultQuery, resultArgs, err := sqlx.In(resultStmt, cmdUUIDs)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "build IN query for windows command results")
|
|
}
|
|
|
|
var results []struct {
|
|
CommandUUID string `db:"command_uuid"`
|
|
HostUUID string `db:"host_uuid"`
|
|
StatusCode string `db:"status_code"`
|
|
}
|
|
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, resultQuery, resultArgs...); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "select windows mdm command results batch")
|
|
}
|
|
|
|
// Build map of command_uuid -> host_uuid -> result
|
|
resultMap := make(map[string]map[string]*fleet.MDMCommandResult)
|
|
for _, res := range results {
|
|
if resultMap[res.CommandUUID] == nil {
|
|
resultMap[res.CommandUUID] = make(map[string]*fleet.MDMCommandResult)
|
|
}
|
|
resultMap[res.CommandUUID][res.HostUUID] = &fleet.MDMCommandResult{
|
|
CommandUUID: res.CommandUUID,
|
|
Status: res.StatusCode,
|
|
HostUUID: res.HostUUID,
|
|
}
|
|
}
|
|
|
|
// Assign commands and results to status objects
|
|
for cmdUUID, refs := range cmdUUIDMap {
|
|
cmd := commandMap[cmdUUID]
|
|
for _, ref := range refs {
|
|
status := statusMap[ref.hostID]
|
|
var cmdRes *fleet.MDMCommandResult
|
|
if resultMap[cmdUUID] != nil {
|
|
cmdRes = resultMap[cmdUUID][ref.hostUUID]
|
|
}
|
|
|
|
status.WipeMDMCommand = cmd
|
|
status.WipeMDMCommandResult = cmdRes
|
|
}
|
|
}
|
|
}
|
|
|
|
// Batch query script results
|
|
if len(scriptRefs) > 0 {
|
|
execIDs := make([]string, 0, len(scriptRefs))
|
|
execIDMap := make(map[string]refKey)
|
|
for _, ref := range scriptRefs {
|
|
execIDs = append(execIDs, ref.uuid)
|
|
execIDMap[ref.uuid] = ref
|
|
}
|
|
|
|
scriptStmt := `
|
|
SELECT
|
|
execution_id,
|
|
exit_code,
|
|
canceled
|
|
FROM
|
|
host_script_results
|
|
WHERE
|
|
execution_id IN (?)
|
|
`
|
|
scriptQuery, scriptArgs, err := sqlx.In(scriptStmt, execIDs)
|
|
if err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "build IN query for script results")
|
|
}
|
|
|
|
var scriptResults []struct {
|
|
ExecutionID string `db:"execution_id"`
|
|
ExitCode *int64 `db:"exit_code"`
|
|
Canceled bool `db:"canceled"`
|
|
}
|
|
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &scriptResults, scriptQuery, scriptArgs...); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "select script results batch")
|
|
}
|
|
|
|
scriptResultMap := make(map[string]*fleet.HostScriptResult)
|
|
for _, sr := range scriptResults {
|
|
scriptResultMap[sr.ExecutionID] = &fleet.HostScriptResult{
|
|
ExecutionID: sr.ExecutionID,
|
|
ExitCode: sr.ExitCode,
|
|
Canceled: sr.Canceled,
|
|
}
|
|
}
|
|
|
|
// Assign script results to status objects
|
|
for execID, ref := range execIDMap {
|
|
status := statusMap[ref.hostID]
|
|
scriptResult := scriptResultMap[execID]
|
|
|
|
switch ref.refType {
|
|
case "lock":
|
|
status.LockScript = scriptResult
|
|
case "unlock":
|
|
status.UnlockScript = scriptResult
|
|
case "wipe":
|
|
status.WipeScript = scriptResult
|
|
}
|
|
}
|
|
}
|
|
|
|
return statusMap, nil
|
|
}
|
|
|
|
func (ds *Datastore) getHostMDMWindowsCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) {
|
|
cmd, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command")
|
|
}
|
|
|
|
// get the MDM command result, which may be not found (indicating the command doesn't exist).
|
|
// If it is pending, then it returns 101, and result will be empty.
|
|
cmdResults, err := ds.GetMDMWindowsCommandResults(ctx, cmdUUID, "")
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command result")
|
|
}
|
|
|
|
// each item in the slice returned by GetMDMWindowsCommandResults is
|
|
// potentially a result for a different host, we need to find the one for
|
|
// that specific host.
|
|
var cmdRes *fleet.MDMCommandResult
|
|
for _, r := range cmdResults {
|
|
if r.HostUUID != hostUUID {
|
|
continue
|
|
}
|
|
|
|
if r.Status == "101" || string(r.Result) == "" {
|
|
// command is still pending
|
|
continue
|
|
}
|
|
|
|
// all statuses for Windows indicate end of processing of the command
|
|
// (there is no equivalent of "NotNow" or "Idle" as for Apple).
|
|
cmdRes = r
|
|
break
|
|
}
|
|
return cmd, cmdRes, nil
|
|
}
|
|
|
|
func (ds *Datastore) getHostMDMAppleCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) {
|
|
cmd, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command")
|
|
}
|
|
|
|
// get the MDM command result, which may be not found (indicating the command
|
|
// is pending). Note that it doesn't return ErrNoRows if not found, it
|
|
// returns success and an empty cmdRes slice.
|
|
cmdResults, err := ds.GetMDMAppleCommandResults(ctx, cmdUUID, hostUUID)
|
|
if err != nil {
|
|
return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command result")
|
|
}
|
|
|
|
// filter by result status to preserve old behavior of this method where it doesn't return pending results.
|
|
var cmdRes *fleet.MDMCommandResult
|
|
for _, r := range cmdResults {
|
|
if r.HostUUID != hostUUID {
|
|
// this should never happen because we already filter by hostUUID, but just in case
|
|
continue
|
|
}
|
|
if r.Status == fleet.MDMAppleStatusAcknowledged || r.Status == fleet.MDMAppleStatusError || r.Status == fleet.MDMAppleStatusCommandFormatError {
|
|
cmdRes = r
|
|
break
|
|
}
|
|
}
|
|
return cmd, cmdRes, nil
|
|
}
|
|
|
|
// LockHostViaScript records a script-based lock request for a host. In one
// (retried) transaction it stores the script contents, queues an internal
// script execution and points host_mdm_actions.lock_ref at that execution.
func (ds *Datastore) LockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
	// On conflict only lock_ref is replaced. At this point the lock is just a
	// pending request; the host becomes "locked" (and any unlock/wipe
	// references are cleared) only once the script execution succeeds.
	const upsertLockRefStmt = `
	INSERT INTO host_mdm_actions
	(
		host_id,
		lock_ref,
		fleet_platform
	)
	VALUES (?,?,?)
	ON DUPLICATE KEY UPDATE
		lock_ref = VALUES(lock_ref)
	`

	return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		contentsRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
		if err != nil {
			return err
		}
		contentsID, _ := contentsRes.LastInsertId()
		request.ScriptContentID = uint(contentsID) //nolint:gosec // dismiss G115

		execution, err := ds.newHostScriptExecutionRequest(ctx, tx, request, true)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "lock host via script create execution")
		}

		if _, err := tx.ExecContext(ctx, upsertLockRefStmt, request.HostID, execution.ExecutionID, hostFleetPlatform); err != nil {
			return ctxerr.Wrap(ctx, err, "lock host via script update mdm actions")
		}
		return nil
	})
}
|
|
|
|
// UnlockHostViaScript records a script-based unlock request for a host. In one
// (retried) transaction it stores the script contents, queues an internal
// script execution and points host_mdm_actions.unlock_ref at that execution,
// clearing any stored unlock_pin.
func (ds *Datastore) UnlockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
	// On conflict only unlock_ref (and unlock_pin) change. The unlock is just a
	// pending request here; the host becomes "unlocked" (and lock/wipe
	// references are cleared) only once the script execution succeeds.
	const upsertUnlockRefStmt = `
	INSERT INTO host_mdm_actions
	(
		host_id,
		unlock_ref,
		fleet_platform
	)
	VALUES (?,?,?)
	ON DUPLICATE KEY UPDATE
		unlock_ref = VALUES(unlock_ref),
		unlock_pin = NULL
	`

	return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		contentsRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
		if err != nil {
			return err
		}
		contentsID, _ := contentsRes.LastInsertId()
		request.ScriptContentID = uint(contentsID) //nolint:gosec // dismiss G115

		execution, err := ds.newHostScriptExecutionRequest(ctx, tx, request, true)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "unlock host via script create execution")
		}

		if _, err := tx.ExecContext(ctx, upsertUnlockRefStmt, request.HostID, execution.ExecutionID, hostFleetPlatform); err != nil {
			return ctxerr.Wrap(ctx, err, "unlock host via script update mdm actions")
		}
		return nil
	})
}
|
|
|
|
// WipeHostViaScript records a script-based wipe request for a host. In one
// (retried) transaction it stores the script contents, queues an internal
// script execution and points host_mdm_actions.wipe_ref at that execution.
func (ds *Datastore) WipeHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error {
	// On conflict only wipe_ref is replaced. The wipe is still pending, so a
	// locked host stays locked and its lock_ref must be preserved.
	const upsertWipeRefStmt = `
	INSERT INTO host_mdm_actions
	(
		host_id,
		wipe_ref,
		fleet_platform
	)
	VALUES (?,?,?)
	ON DUPLICATE KEY UPDATE
		wipe_ref = VALUES(wipe_ref)
	`

	return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		contentsRes, err := insertScriptContents(ctx, tx, request.ScriptContents)
		if err != nil {
			return err
		}
		contentsID, _ := contentsRes.LastInsertId()
		request.ScriptContentID = uint(contentsID) //nolint:gosec // dismiss G115

		execution, err := ds.newHostScriptExecutionRequest(ctx, tx, request, true)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "wipe host via script create execution")
		}

		if _, err := tx.ExecContext(ctx, upsertWipeRefStmt, request.HostID, execution.ExecutionID, hostFleetPlatform); err != nil {
			return ctxerr.Wrap(ctx, err, "wipe host via script update mdm actions")
		}
		return nil
	})
}
|
|
|
|
// UnlockHostManually records that an unlock was requested for a host whose
// unlock requires manual intervention (on macOS, entering the PIN on the
// device). unlock_ref stores the time of the FIRST request as a UTC
// "YYYY-MM-DD HH:MM:SS" string. Later calls leave an existing value untouched,
// so the host stays "pending unlock" until the device is actually unlocked,
// which is detected when it sends an Idle MDM request.
//
// ts is converted to UTC here, which keeps the stored value consistent with
// the STR_TO_DATE comparison in CleanAppleMDMLock.
func (ds *Datastore) UnlockHostManually(ctx context.Context, hostID uint, hostFleetPlatform string, ts time.Time) error {
	const stmt = `
	INSERT INTO host_mdm_actions
	(
		host_id,
		unlock_ref,
		fleet_platform
	)
	VALUES (?, ?, ?)
	ON DUPLICATE KEY UPDATE
		-- do not overwrite if a value is already set
		unlock_ref = IF(unlock_ref IS NULL, VALUES(unlock_ref), unlock_ref)
	`

	firstRequestedAt := ts.UTC().Format(time.DateTime)
	_, err := ds.writer(ctx).ExecContext(ctx, stmt, hostID, firstRequestedAt, hostFleetPlatform)
	return ctxerr.Wrap(ctx, err, "record manual unlock host request")
}
|
|
|
|
// buildHostLockWipeStatusUpdateStmt builds an UPDATE statement (without a
// WHERE clause) that applies the outcome of a lock, unlock or wipe action to
// host_mdm_actions. refCol is the column tracking the action: "lock_ref",
// "unlock_ref" or "wipe_ref".
//
// When joinPart is non-empty, the table is aliased as "hma", joinPart follows
// the alias, and every assigned column is qualified with "hma.".
//
// On success:
//   - lock_ref: unlock_ref and wipe_ref are cleared. When setUnlockRef is true
//     (currently only for Apple MDM devices, which can be unlocked any time
//     after the lock), unlock_ref is instead set to UTC_TIMESTAMP(), which
//     matches the timezone used by the comparison in CleanAppleMDMLock.
//     unlock_pin is kept because the lock request stored the PIN needed for
//     an eventual unlock.
//   - unlock_ref: lock_ref, unlock_ref, unlock_pin and wipe_ref are cleared,
//     since unlocked is the default state.
//   - wipe_ref: lock_ref, unlock_ref and unlock_pin are cleared.
//
// On failure only refCol is cleared, so the host stays in its previous state.
func buildHostLockWipeStatusUpdateStmt(refCol string, succeeded bool, joinPart string, setUnlockRef bool) string {
	target := "host_mdm_actions "
	colPrefix := ""
	if joinPart != "" {
		target += "hma " + joinPart
		colPrefix = "hma."
	}

	// Each entry is a {column, value} pair of the SET clause, in output order.
	var assignments [][2]string
	switch {
	case !succeeded:
		assignments = [][2]string{{refCol, "NULL"}}
	case refCol == "lock_ref" && setUnlockRef:
		assignments = [][2]string{{"unlock_ref", "UTC_TIMESTAMP()"}, {"wipe_ref", "NULL"}}
	case refCol == "lock_ref":
		assignments = [][2]string{{"unlock_ref", "NULL"}, {"wipe_ref", "NULL"}}
	case refCol == "unlock_ref":
		assignments = [][2]string{{"lock_ref", "NULL"}, {"unlock_ref", "NULL"}, {"unlock_pin", "NULL"}, {"wipe_ref", "NULL"}}
	case refCol == "wipe_ref":
		assignments = [][2]string{{"lock_ref", "NULL"}, {"unlock_ref", "NULL"}, {"unlock_pin", "NULL"}}
	}

	setClause := ""
	for i, pair := range assignments {
		if i > 0 {
			setClause += ", "
		}
		setClause += fmt.Sprintf("%s%s = %s", colPrefix, pair[0], pair[1])
	}
	return "UPDATE " + target + " SET " + setClause
}
|
|
|
|
func (ds *Datastore) UpdateHostLockWipeStatusFromAppleMDMResult(ctx context.Context, hostUUID, cmdUUID, requestType string, succeeded bool) error {
|
|
// a bit of MDM protocol leaking in the mysql layer, but it's either that or
|
|
// the other way around (MDM protocol would translate to database column)
|
|
var refCol string
|
|
var setUnlockRef bool
|
|
switch requestType {
|
|
case "EraseDevice":
|
|
refCol = "wipe_ref"
|
|
case "DeviceLock":
|
|
refCol = "lock_ref"
|
|
setUnlockRef = true
|
|
case "EnableLostMode":
|
|
refCol = "lock_ref"
|
|
case "DisableLostMode":
|
|
refCol = "unlock_ref"
|
|
default:
|
|
return nil
|
|
}
|
|
return updateHostLockWipeStatusFromResultAndHostUUID(ctx, ds.writer(ctx), hostUUID, refCol, cmdUUID, succeeded, setUnlockRef)
|
|
}
|
|
|
|
func updateHostLockWipeStatusFromResultAndHostUUID(
|
|
ctx context.Context, tx sqlx.ExtContext, hostUUID, refCol, cmdUUID string, succeeded bool, setUnlockRef bool,
|
|
) error {
|
|
stmt := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, `JOIN hosts h ON hma.host_id = h.id`, setUnlockRef)
|
|
stmt += ` WHERE h.uuid = ? AND hma.` + refCol + ` = ?`
|
|
_, err := tx.ExecContext(ctx, stmt, hostUUID, cmdUUID)
|
|
return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result via host uuid")
|
|
}
|
|
|
|
// updateHostLockWipeStatusFromResult applies the outcome of the action tracked
// in refCol to the host_mdm_actions row of hostID. Unlike the host-UUID
// variant, it does not check which command or execution refCol points to.
func updateHostLockWipeStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, refCol string, succeeded bool) error {
	query := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, "", false) + ` WHERE host_id = ?`
	_, err := tx.ExecContext(ctx, query, hostID)
	return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result")
}
|
|
|
|
func (ds *Datastore) updateUninstallStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, executionID string, exitCode int) error {
|
|
stmt := `
|
|
UPDATE host_software_installs SET uninstall_script_exit_code = ? WHERE execution_id = ? AND host_id = ?
|
|
`
|
|
if _, err := tx.ExecContext(ctx, stmt, exitCode, executionID, hostID); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "update uninstall status from result")
|
|
}
|
|
// NOTE: no need to call activateNextUpcomingActivity here as this function
|
|
// is called from SetHostScriptExecutionResult which will call it before
|
|
// completing.
|
|
return nil
|
|
}
|
|
|
|
func (ds *Datastore) CleanupUnusedScriptContents(ctx context.Context) error {
|
|
deleteStmt := `
|
|
DELETE FROM
|
|
script_contents
|
|
WHERE
|
|
NOT EXISTS (
|
|
SELECT 1 FROM host_script_results WHERE script_content_id = script_contents.id)
|
|
AND NOT EXISTS (
|
|
SELECT 1 FROM scripts WHERE script_content_id = script_contents.id)
|
|
AND NOT EXISTS (
|
|
SELECT 1 FROM software_installers si
|
|
WHERE script_contents.id IN (si.install_script_content_id, si.post_install_script_content_id, si.uninstall_script_content_id)
|
|
)
|
|
AND NOT EXISTS (
|
|
SELECT 1 FROM setup_experience_scripts WHERE script_content_id = script_contents.id
|
|
)
|
|
AND NOT EXISTS (
|
|
SELECT 1 FROM script_upcoming_activities WHERE script_content_id = script_contents.id
|
|
)
|
|
`
|
|
_, err := ds.writer(ctx).ExecContext(ctx, deleteStmt)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "cleaning up unused script contents")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (ds *Datastore) getOrGenerateScriptContentsID(ctx context.Context, contents string) (uint, error) {
|
|
csum := md5ChecksumScriptContent(contents)
|
|
scriptContentsID, err := ds.optimisticGetOrInsert(ctx,
|
|
¶meterizedStmt{
|
|
Statement: `SELECT id FROM script_contents WHERE md5_checksum = UNHEX(?)`,
|
|
Args: []interface{}{csum},
|
|
},
|
|
¶meterizedStmt{
|
|
Statement: `INSERT INTO script_contents (md5_checksum, contents) VALUES (UNHEX(?), ?)`,
|
|
Args: []interface{}{csum, contents},
|
|
},
|
|
)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return scriptContentsID, nil
|
|
}
|
|
|
|
// teamIDEq reports whether two optional team IDs denote the same team. Both
// nil means "no team" and counts as equal; exactly one nil is never equal.
func teamIDEq(teamID1, teamID2 *uint) bool {
	if teamID1 == nil || teamID2 == nil {
		// Equal only when both are nil.
		return teamID1 == teamID2
	}
	return *teamID1 == *teamID2
}
|
|
|
|
func (ds *Datastore) batchExecuteScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint, batchExecID string) error {
|
|
script, err := ds.Script(ctx, scriptID)
|
|
if err != nil {
|
|
return fleet.NewInvalidArgumentError("script_id", err.Error())
|
|
}
|
|
|
|
invalidHostIDPlatform := "batch-invalid-hostid"
|
|
|
|
// We need full host info to check if hosts are able to run scripts, see svc.RunHostScript
|
|
fullHosts := make([]*fleet.Host, 0, len(hostIDs))
|
|
|
|
// The execution results to be stored in the database
|
|
executions := make([]fleet.BatchExecutionHost, 0, len(fullHosts))
|
|
|
|
// Check that all hosts exist before attempting to process them
|
|
for _, hostID := range hostIDs {
|
|
host, err := ds.Host(ctx, hostID)
|
|
if err != nil {
|
|
fullHosts = append(fullHosts, &fleet.Host{
|
|
ID: hostID,
|
|
Platform: invalidHostIDPlatform,
|
|
})
|
|
continue
|
|
}
|
|
|
|
fullHosts = append(fullHosts, host)
|
|
}
|
|
|
|
if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
|
|
for _, host := range fullHosts {
|
|
// Host doesn't exist anymore
|
|
if host.Platform == invalidHostIDPlatform {
|
|
executions = append(executions, fleet.BatchExecutionHost{
|
|
HostID: host.ID,
|
|
Error: &fleet.BatchExecuteInvalidHost,
|
|
})
|
|
continue
|
|
}
|
|
|
|
// Non-orbit-enrolled host (iOS, android)
|
|
noNodeKey := host.OrbitNodeKey == nil || *host.OrbitNodeKey == ""
|
|
// Scripts disabled on host
|
|
scriptsDisabled := host.ScriptsEnabled != nil && !*host.ScriptsEnabled
|
|
|
|
if noNodeKey || scriptsDisabled {
|
|
executions = append(executions, fleet.BatchExecutionHost{
|
|
HostID: host.ID,
|
|
Error: &fleet.BatchExecuteIncompatibleFleetd,
|
|
})
|
|
continue
|
|
}
|
|
|
|
if !fleet.ValidateScriptPlatform(script.Name, host.Platform) {
|
|
executions = append(executions, fleet.BatchExecutionHost{
|
|
HostID: host.ID,
|
|
Error: &fleet.BatchExecuteIncompatiblePlatform,
|
|
})
|
|
continue
|
|
}
|
|
|
|
executionID, _, err := ds.insertNewHostScriptExecution(ctx, tx, &fleet.HostScriptRequestPayload{
|
|
HostID: host.ID,
|
|
UserID: userID,
|
|
ScriptID: &script.ID,
|
|
ScriptContentID: script.ScriptContentID,
|
|
}, false)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "queueing script for bulk execution")
|
|
}
|
|
|
|
executions = append(executions, fleet.BatchExecutionHost{
|
|
HostID: host.ID,
|
|
ExecutionID: &executionID,
|
|
})
|
|
}
|
|
|
|
_, err := tx.ExecContext(
|
|
ctx,
|
|
`INSERT INTO batch_activities (execution_id, script_id, status, activity_type, num_targeted, started_at) VALUES (?, ?, ?, ?, ?, NOW())
|
|
ON DUPLICATE KEY UPDATE status = VALUES(status), started_at = VALUES(started_at)`,
|
|
batchExecID,
|
|
script.ID,
|
|
fleet.ScheduledBatchExecutionStarted,
|
|
fleet.BatchExecutionActivityScript,
|
|
len(hostIDs),
|
|
)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "failed to insert new batch execution")
|
|
}
|
|
|
|
args := make([]map[string]any, 0, len(executions))
|
|
for _, execHost := range executions {
|
|
args = append(args, map[string]any{
|
|
"batch_id": batchExecID,
|
|
"host_id": execHost.HostID,
|
|
"host_execution_id": execHost.ExecutionID,
|
|
"error": execHost.Error,
|
|
})
|
|
}
|
|
|
|
insertStmt := `
|
|
INSERT INTO batch_activity_host_results (
|
|
batch_execution_id,
|
|
host_id,
|
|
host_execution_id,
|
|
error
|
|
) VALUES (
|
|
:batch_id,
|
|
:host_id,
|
|
:host_execution_id,
|
|
:error
|
|
) ON DUPLICATE KEY UPDATE host_execution_id = VALUES(host_execution_id), error = VALUES(error)`
|
|
|
|
if _, err := sqlx.NamedExecContext(ctx, tx, insertStmt, args); err != nil {
|
|
return ctxerr.Wrap(ctx, err, "associating script executions with batch job")
|
|
}
|
|
|
|
return nil
|
|
}); err != nil {
|
|
return fmt.Errorf("creating bulk execution order: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// BatchExecuteScript immediately queues scriptID on all hostIDs as a new batch
// and returns the batch execution ID. Every host must be loadable and belong
// to the same team as the script; otherwise nothing is queued and an error is
// returned.
func (ds *Datastore) BatchExecuteScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint) (string, error) {
	script, err := ds.Script(ctx, scriptID)
	if err != nil {
		return "", fleet.NewInvalidArgumentError("script_id", err.Error())
	}

	for _, hostID := range hostIDs {
		host, err := ds.HostLite(ctx, hostID)
		switch {
		case err != nil:
			return "", fmt.Errorf("unable to load host information for %d: %w", hostID, err)
		case !teamIDEq(host.TeamID, script.TeamID):
			return "", ctxerr.Errorf(ctx, "all hosts must be on the same fleet as the script")
		}
	}

	batchExecID := uuid.New().String()
	if err := ds.batchExecuteScript(ctx, userID, scriptID, hostIDs, batchExecID); err != nil {
		return "", ctxerr.Wrap(ctx, err, "immediate batch execution")
	}
	return batchExecID, nil
}
|
|
|
|
// BatchScheduleScript schedules scriptID to run on hostIDs no earlier than
// notBefore and returns the new batch execution ID.
//
// It queues a worker job (fleet.BatchActivityScriptsJobName) whose arguments
// carry the batch execution ID, inserts the batch_activities row in the
// "scheduled" state, and inserts one batch_activity_host_results row per
// targeted host. The scripts themselves are presumably queued on the hosts
// later, when the job runs RunScheduledBatchActivity.
func (ds *Datastore) BatchScheduleScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint, notBefore time.Time) (string, error) {
	batchExecID := uuid.New().String()

	const batchActivitiesStmt = `INSERT INTO batch_activities (execution_id, job_id, script_id, user_id, status, activity_type, num_targeted) VALUES (?, ?, ?, ?, ?, ?, ?)`
	const batchHostsStmt = `INSERT INTO batch_activity_host_results (batch_execution_id, host_id) VALUES (:exec_id, :host_id)`

	argBytes, err := json.Marshal(fleet.BatchActivityScriptJobArgs{
		ExecutionID: batchExecID,
	})
	if err != nil {
		return "", ctxerr.Wrap(ctx, err, "encoding job args")
	}

	if err := ds.withTx(ctx, func(tx sqlx.ExtContext) error {
		// NOTE(review): NewJob is called on ds rather than with tx, so the job
		// row is presumably written outside this transaction and would survive
		// a rollback of the inserts below; confirm.
		job, err := ds.NewJob(ctx, &fleet.Job{
			Name:      fleet.BatchActivityScriptsJobName,
			Args:      (*json.RawMessage)(&argBytes),
			State:     fleet.JobStateQueued,
			NotBefore: notBefore.UTC(),
		})
		if err != nil {
			return ctxerr.Wrap(ctx, err, "creating new job")
		}

		_, err = tx.ExecContext(
			ctx,
			batchActivitiesStmt,
			batchExecID,
			job.ID,
			scriptID,
			userID,
			fleet.ScheduledBatchExecutionScheduled,
			fleet.BatchExecutionActivityScript,
			len(hostIDs),
		)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "inserting new batch activity")
		}

		args := make([]map[string]any, 0, len(hostIDs))
		for _, hostID := range hostIDs {
			args = append(args, map[string]any{
				"exec_id": batchExecID,
				"host_id": hostID,
			})
		}

		if _, err := sqlx.NamedExecContext(ctx, tx, batchHostsStmt, args); err != nil {
			return ctxerr.Wrap(ctx, err, "inserting batch host results")
		}

		return nil
	}); err != nil {
		return "", ctxerr.Wrap(ctx, err, "creating scheduled script execution")
	}

	return batchExecID, nil
}
|
|
|
|
// CancelBatchScript cancels the batch script run identified by executionID.
//
// A finished batch is left untouched. Otherwise, the batch's worker job (if
// any) is marked successful so it does not run, and then:
//   - for a started batch, every host execution that is still pending (no
//     batch-level error, not canceled, no exit code) is canceled, the batch is
//     flagged as canceled, and markActivitiesAsCompleted is invoked;
//   - for a batch that is only scheduled, the batch is marked finished and
//     canceled, with all its host rows counted as canceled.
func (ds *Datastore) CancelBatchScript(ctx context.Context, executionID string) error {
	// Executions that have not been activated yet have no host_script_results
	// row, so the hsr conditions must accept a missing row. Putting
	// "hsr.canceled = 0 AND hsr.exit_code IS NULL" directly in the WHERE
	// clause null-rejects the LEFT JOIN (making it an inner join) and silently
	// skips those pending executions.
	stmt := `
	SELECT
		bahr.host_execution_id,
		bahr.host_id
	FROM
		batch_activity_host_results bahr
	LEFT JOIN
		host_script_results hsr ON bahr.host_execution_id = hsr.execution_id
	WHERE
		bahr.batch_execution_id = ?
		AND bahr.host_execution_id IS NOT NULL
		AND bahr.error IS NULL
		AND (
			hsr.execution_id IS NULL
			OR (hsr.canceled = 0 AND hsr.exit_code IS NULL)
		)`

	stmtSetCanceled := `
	UPDATE
		batch_activities ba
	SET
		finished_at = NOW(),
		status = 'finished',
		canceled = 1,
		num_canceled = (SELECT COUNT(*) FROM batch_activity_host_results WHERE batch_execution_id = ba.execution_id)
	WHERE
		ba.execution_id = ?`

	stmtCanceled := `
	UPDATE
		batch_activities
	SET
		canceled = 1
	WHERE
		execution_id = ?`

	activity, err := ds.GetBatchActivity(ctx, executionID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "getting batch activity")
	}

	if activity.Status == fleet.ScheduledBatchExecutionFinished {
		return nil
	}

	if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		// If job worker exists, mark it as complete to stop it from running
		if jobID := activity.JobID; jobID != nil {
			job, err := ds.GetJob(ctx, *jobID)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "failed to find job associated with batch activity")
			}

			job.State = fleet.JobStateSuccess

			if _, err := ds.updateJob(ctx, tx, *jobID, job); err != nil {
				return ctxerr.Wrap(ctx, err, "updating batch activity job")
			}
		}

		if activity.Status == fleet.ScheduledBatchExecutionStarted {
			// If the batch activity has started, we need to cancel anything in progress or queued
			toCancel := []struct {
				HostExecutionID string `db:"host_execution_id"`
				HostID          uint   `db:"host_id"`
			}{}

			if err := sqlx.SelectContext(ctx, tx, &toCancel, stmt, executionID); err != nil {
				return ctxerr.Wrap(ctx, err, "selecting hosts to cancel")
			}

			for _, host := range toCancel {
				if _, err := ds.cancelHostUpcomingActivity(ctx, tx, host.HostID, host.HostExecutionID); err != nil {
					if fleet.IsNotFound(err) {
						// Nothing left to cancel for this host, e.g. the
						// activity is no longer in its queue.
						continue
					}
					return ctxerr.Wrap(ctx, err, "canceling upcoming activity")
				}
			}

			if _, err := tx.ExecContext(ctx, stmtCanceled, executionID); err != nil {
				return ctxerr.Wrap(ctx, err, "setting canceled column")
			}

			if err := ds.markActivitiesAsCompleted(ctx, tx); err != nil {
				return ctxerr.Wrap(ctx, err, "marking job as complete and summarizing counts")
			}
		} else {
			// The batch activity is scheduled, but not started
			if _, err := tx.ExecContext(ctx, stmtSetCanceled, executionID); err != nil {
				return ctxerr.Wrap(ctx, err, "setting canceled host count")
			}
		}

		return nil
	}); err != nil {
		return ctxerr.Wrap(ctx, err, "cancel batch script db transaction")
	}

	return nil
}
|
|
|
|
// GetBatchActivity loads the batch_activities row identified by executionID,
// including the name of its script (NULL-tolerant LEFT JOIN on scripts).
// The error from the query, including "no rows", is returned wrapped.
func (ds *Datastore) GetBatchActivity(ctx context.Context, executionID string) (*fleet.BatchActivity, error) {
	const stmt = `
	SELECT
		ba.id,
		ba.script_id,
		s.name as script_name,
		ba.execution_id,
		ba.user_id,
		ba.job_id,
		ba.status,
		ba.activity_type,
		ba.num_targeted,
		ba.num_pending,
		ba.num_ran,
		ba.num_errored,
		ba.num_incompatible,
		ba.num_canceled,
		ba.created_at,
		ba.updated_at,
		ba.started_at,
		ba.finished_at,
		ba.canceled
	FROM
		batch_activities ba
	LEFT JOIN
		scripts s ON s.id = ba.script_id
	WHERE
		execution_id = ?`

	var activity fleet.BatchActivity
	if err := sqlx.GetContext(ctx, ds.reader(ctx), &activity, stmt, executionID); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "selecting batch activity")
	}
	return &activity, nil
}
|
|
|
|
// GetBatchActivityHostResults returns the per-host rows recorded for the batch
// identified by executionID. When there are none, it returns an empty
// (non-nil) slice.
func (ds *Datastore) GetBatchActivityHostResults(ctx context.Context, executionID string) ([]*fleet.BatchActivityHostResult, error) {
	const stmt = `
	SELECT
		id,
		batch_execution_id,
		host_id,
		host_execution_id,
		error
	FROM
		batch_activity_host_results
	WHERE
		batch_execution_id = ?`

	// Start from a non-nil slice so "no results" is an empty list, not nil.
	hostResults := make([]*fleet.BatchActivityHostResult, 0)
	err := sqlx.SelectContext(ctx, ds.reader(ctx), &hostResults, stmt, executionID)
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "selecting batch activity host results")
	}
	return hostResults, nil
}
|
|
|
|
// RunScheduledBatchActivity starts a previously scheduled batch script run.
// It refuses batches that were already started, were canceled, or have no
// script. Otherwise it queues the script on every host recorded for the batch.
func (ds *Datastore) RunScheduledBatchActivity(ctx context.Context, executionID string) error {
	activity, err := ds.GetBatchActivity(ctx, executionID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "getting batch activity")
	}

	switch {
	case activity.Status != fleet.ScheduledBatchExecutionScheduled:
		return ctxerr.New(ctx, "batch job has already been started")
	case activity.Canceled:
		return ctxerr.New(ctx, "batch job was canceled")
	case activity.ScriptID == nil:
		return ctxerr.New(ctx, "no script ID present in batch activity")
	}

	script, err := ds.Script(ctx, *activity.ScriptID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "could not get script")
	}

	hostRows, err := ds.GetBatchActivityHostResults(ctx, executionID)
	if err != nil {
		return ctxerr.Wrap(ctx, err, "getting batch activity host results")
	}

	targetHostIDs := make([]uint, 0, len(hostRows))
	for _, row := range hostRows {
		targetHostIDs = append(targetHostIDs, row.HostID)
	}

	if err := ds.batchExecuteScript(ctx, activity.UserID, script.ID, targetHostIDs, activity.BatchExecutionID); err != nil {
		return ctxerr.Wrap(ctx, err, "scheduled batch script execution")
	}
	return nil
}
|
|
|
|
// BatchExecuteSummary returns the script details (script ID, name, team and
// creation time) of the batch identified by executionID, together with live
// host counts computed from batch_activity_host_results and
// host_script_results. A nil team ID is reported as 0.
//
// Deprecated; will be removed in favor of ListBatchScriptExecutions when the batch script details page is ready.
func (ds *Datastore) BatchExecuteSummary(ctx context.Context, executionID string) (*fleet.BatchActivity, error) {
	stmtExecutions := `
	SELECT
		COUNT(*) as num_targeted,
		COUNT(bsehr.error) as num_did_not_run,
		COUNT(CASE WHEN hsr.exit_code = 0 THEN 1 END) as num_succeeded,
		COUNT(CASE WHEN hsr.exit_code <> 0 THEN 1 END) as num_failed,
		COUNT(CASE WHEN hsr.canceled = 1 AND hsr.exit_code IS NULL THEN 1 END) as num_cancelled
	FROM
		batch_activity_host_results bsehr
	LEFT JOIN
		host_script_results hsr
		ON bsehr.host_execution_id = hsr.execution_id
	WHERE
		bsehr.batch_execution_id = ?`

	stmtScriptDetails := `
	SELECT
		script_id,
		s.name as script_name,
		s.team_id as team_id,
		bse.created_at as created_at
	FROM
		batch_activities bse
	JOIN
		scripts s
		ON bse.script_id = s.id
	WHERE
		bse.execution_id = ?`

	var summary fleet.BatchActivity
	var counts struct {
		NumTargeted  uint `db:"num_targeted"`
		NumDidNotRun uint `db:"num_did_not_run"`
		NumSucceeded uint `db:"num_succeeded"`
		NumFailed    uint `db:"num_failed"`
		NumCancelled uint `db:"num_cancelled"`
	}
	// Fill out the execution details
	if err := sqlx.GetContext(ctx, ds.reader(ctx), &counts, stmtExecutions, executionID); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary")
	}

	summary.NumTargeted = &counts.NumTargeted
	// NumRan is the number of hosts that ran the script successfully (exit code 0).
	summary.NumRan = &counts.NumSucceeded
	// NumErrored includes both hosts where the script failed (non-zero exit
	// code) and hosts where it did not run (a batch-level error was recorded).
	summary.NumErrored = ptr.Uint(counts.NumFailed + counts.NumDidNotRun)
	// NumCanceled is the number of hosts whose execution was canceled before
	// producing an exit code.
	summary.NumCanceled = &counts.NumCancelled
	// NumPending is whatever has no outcome yet. The buckets above are
	// expected to be disjoint; a host row with an error has no execution ID.
	summary.NumPending = ptr.Uint(counts.NumTargeted - (counts.NumSucceeded + counts.NumFailed + counts.NumDidNotRun + counts.NumCancelled))

	// Fill out the script details. This scans only the selected columns into
	// summary, so the counts set above are kept.
	if err := sqlx.GetContext(ctx, ds.reader(ctx), &summary, stmtScriptDetails, executionID); err != nil {
		return nil, ctxerr.Wrap(ctx, err, "selecting script information for bulk execution summary")
	}

	if summary.TeamID == nil {
		summary.TeamID = ptr.Uint(0)
	}

	return &summary, nil
}
|
|
|
|
func (ds *Datastore) ListBatchScriptExecutions(ctx context.Context, filter fleet.BatchExecutionStatusFilter) ([]fleet.BatchActivity, error) {
|
|
stmtExecutions := `
|
|
SELECT *
|
|
FROM (
|
|
-- If batch is finished, get the cached host result counts
|
|
SELECT
|
|
COALESCE(ba.num_targeted, 0) AS num_targeted,
|
|
COALESCE(ba.num_incompatible, 0) AS num_incompatible,
|
|
COALESCE(ba.num_ran, 0) AS num_ran,
|
|
COALESCE(ba.num_errored, 0) AS num_errored,
|
|
COALESCE(ba.num_canceled, 0) AS num_canceled,
|
|
COALESCE(ba.num_pending, 0) AS num_pending,
|
|
ba.execution_id,
|
|
ba.script_id,
|
|
ba.status,
|
|
ba.canceled,
|
|
ba.finished_at,
|
|
ba.started_at,
|
|
s.name AS script_name,
|
|
s.global_or_team_id AS team_id,
|
|
ba.created_at AS created_at,
|
|
j.not_before AS not_before,
|
|
ba.id AS id
|
|
FROM batch_activities ba
|
|
JOIN scripts s ON ba.script_id = s.id
|
|
LEFT JOIN jobs j ON j.id = ba.job_id
|
|
WHERE ( %s ) AND ba.status = 'finished'
|
|
|
|
UNION ALL
|
|
|
|
-- If batch is not finished, calculate the host result counts live.
|
|
SELECT
|
|
COUNT(bahr.host_id) AS num_targeted,
|
|
COUNT(bahr.error) AS num_incompatible,
|
|
COUNT(IF(hsr.exit_code = 0, 1, NULL)) AS num_ran,
|
|
COUNT(IF(hsr.exit_code <> 0, 1, NULL)) AS num_errored,
|
|
COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba.canceled = 1), 1, NULL)) AS num_cancelled,
|
|
(
|
|
COUNT(bahr.host_id)
|
|
- COUNT(bahr.error)
|
|
- COUNT(IF(hsr.exit_code = 0, 1, NULL))
|
|
- COUNT(IF(hsr.exit_code <> 0, 1, NULL))
|
|
- COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba.canceled = 1), 1, NULL))
|
|
) AS num_pending,
|
|
ba.execution_id,
|
|
ba.script_id,
|
|
ba.status,
|
|
ba.canceled,
|
|
ba.finished_at,
|
|
ba.started_at,
|
|
s.name AS script_name,
|
|
s.global_or_team_id AS team_id,
|
|
ba.created_at AS created_at,
|
|
j.not_before AS not_before,
|
|
ba.id AS id
|
|
FROM batch_activities ba
|
|
LEFT JOIN batch_activity_host_results bahr
|
|
ON ba.execution_id = bahr.batch_execution_id
|
|
LEFT JOIN host_script_results hsr
|
|
ON bahr.host_execution_id = hsr.execution_id
|
|
JOIN scripts s
|
|
ON ba.script_id = s.id
|
|
LEFT JOIN jobs j
|
|
ON j.id = ba.job_id
|
|
WHERE ( %s ) AND ba.status <> 'finished'
|
|
GROUP BY ba.id
|
|
) AS u
|
|
ORDER BY
|
|
%s
|
|
LIMIT %d OFFSET %d
|
|
`
|
|
limit := 10
|
|
offset := 0
|
|
args := []any{}
|
|
orderBy := []string{"u.created_at DESC", "u.id DESC"}
|
|
whereClauses := make([]string, 0, 2)
|
|
// If an execution ID is provided, use it to filter the results.
|
|
if filter.ExecutionID != nil && *filter.ExecutionID != "" {
|
|
whereClauses = append(whereClauses, "ba.execution_id = ?")
|
|
args = append(args, *filter.ExecutionID)
|
|
} else {
|
|
// Otherwise filter by status and/or team ID.
|
|
if filter.Status != nil && *filter.Status != "" {
|
|
whereClauses = append(whereClauses, "ba.status = ?")
|
|
args = append(args, *filter.Status)
|
|
switch *filter.Status {
|
|
case string(fleet.ScheduledBatchExecutionScheduled):
|
|
orderBy = append([]string{"u.not_before ASC"}, orderBy...)
|
|
case string(fleet.ScheduledBatchExecutionStarted):
|
|
orderBy = append([]string{"u.started_at DESC"}, orderBy...)
|
|
case string(fleet.ScheduledBatchExecutionFinished):
|
|
orderBy = append([]string{"u.finished_at DESC"}, orderBy...)
|
|
default:
|
|
// no additional ordering
|
|
}
|
|
}
|
|
if filter.TeamID != nil {
|
|
whereClauses = append(whereClauses, "s.global_or_team_id = ?")
|
|
args = append(args, *filter.TeamID)
|
|
}
|
|
}
|
|
|
|
// Double up the args to use them in both WHERE clauses.
|
|
args = append(args, args...)
|
|
|
|
// Use pagination parameters if provided.
|
|
if filter.Limit != nil {
|
|
limit = int(*filter.Limit) //nolint:gosec // dismiss G115
|
|
}
|
|
if filter.Offset != nil {
|
|
offset = int(*filter.Offset) //nolint:gosec // dismiss G115
|
|
}
|
|
where := strings.Join(whereClauses, " AND ")
|
|
stmtExecutions = fmt.Sprintf(stmtExecutions, where, where, strings.Join(orderBy, ", "), limit, offset)
|
|
var summary []fleet.BatchActivity
|
|
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &summary, stmtExecutions, args...); err != nil {
|
|
return nil, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary")
|
|
}
|
|
|
|
return summary, nil
|
|
}
|
|
|
|
func (ds *Datastore) CountBatchScriptExecutions(ctx context.Context, filter fleet.BatchExecutionStatusFilter) (int64, error) {
|
|
stmtExecutions := `
|
|
SELECT
|
|
COUNT(*)
|
|
FROM
|
|
batch_activities ba
|
|
JOIN
|
|
scripts s
|
|
ON ba.script_id = s.id
|
|
WHERE
|
|
%s
|
|
`
|
|
args := []any{}
|
|
whereClauses := make([]string, 0, 2)
|
|
if filter.Status != nil && *filter.Status != "" {
|
|
whereClauses = append(whereClauses, "ba.status = ?")
|
|
args = append(args, *filter.Status)
|
|
}
|
|
if filter.TeamID != nil {
|
|
whereClauses = append(whereClauses, "s.global_or_team_id = ?")
|
|
args = append(args, *filter.TeamID)
|
|
}
|
|
where := strings.Join(whereClauses, " AND ")
|
|
stmtExecutions = fmt.Sprintf(stmtExecutions, where)
|
|
|
|
var count int64
|
|
if err := sqlx.GetContext(ctx, ds.reader(ctx), &count, stmtExecutions, args...); err != nil {
|
|
return 0, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary")
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
func (ds *Datastore) markActivitiesAsCompleted(ctx context.Context, tx sqlx.ExtContext) error {
|
|
const stmt = `
|
|
UPDATE batch_activities AS ba
|
|
JOIN (
|
|
SELECT
|
|
ba2.id AS batch_id,
|
|
COUNT(bahr.host_id) AS num_targeted,
|
|
COUNT(bahr.error) AS num_incompatible,
|
|
COUNT(IF(hsr.exit_code = 0, 1, NULL)) AS num_ran,
|
|
COUNT(IF(hsr.exit_code <> 0, 1, NULL)) AS num_errored,
|
|
COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba2.canceled = 1), 1, NULL)) AS num_canceled
|
|
FROM batch_activities AS ba2
|
|
LEFT JOIN batch_activity_host_results AS bahr
|
|
ON ba2.execution_id = bahr.batch_execution_id
|
|
LEFT JOIN host_script_results AS hsr
|
|
ON bahr.host_execution_id = hsr.execution_id
|
|
WHERE ba2.status = 'started'
|
|
GROUP BY ba2.id
|
|
HAVING (num_incompatible + num_ran + num_errored + num_canceled) >= num_targeted
|
|
) AS agg
|
|
ON agg.batch_id = ba.id
|
|
SET
|
|
ba.status = 'finished',
|
|
ba.finished_at = NOW(),
|
|
ba.num_targeted = agg.num_targeted,
|
|
ba.num_incompatible = agg.num_incompatible,
|
|
ba.num_ran = agg.num_ran,
|
|
ba.num_errored = agg.num_errored,
|
|
ba.num_canceled = agg.num_canceled,
|
|
ba.num_pending = 0
|
|
WHERE ba.status = 'started';
|
|
`
|
|
// TODO -- use `RETURNING` to return the IDs of the updated activities?
|
|
_, err := tx.ExecContext(ctx, stmt)
|
|
if err != nil {
|
|
return ctxerr.Wrap(ctx, err, "marking activities as completed")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (ds *Datastore) MarkActivitiesAsCompleted(ctx context.Context) error {
|
|
return ds.markActivitiesAsCompleted(ctx, ds.writer(ctx))
|
|
}
|