package mysql import ( "context" "crypto/md5" //nolint:gosec "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" "strings" "time" "unicode/utf8" constants "github.com/fleetdm/fleet/v4/pkg/scripts" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/go-kit/log/level" "github.com/google/uuid" "github.com/jmoiron/sqlx" ) func (ds *Datastore) NewHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) { var res *fleet.HostScriptResult return res, ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var err error if request.ScriptContentID == 0 { // then we are doing a sync execution, so create the contents first scRes, err := insertScriptContents(ctx, tx, request.ScriptContents) if err != nil { return err } id, _ := scRes.LastInsertId() request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115 } res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, false) return err }) } func (ds *Datastore) newHostScriptExecutionRequest(ctx context.Context, tx sqlx.ExtContext, request *fleet.HostScriptRequestPayload, isInternal bool) (*fleet.HostScriptResult, error) { const ( getStmt = ` SELECT ua.id, ua.host_id, ua.execution_id, ua.created_at, sua.script_id, sua.policy_id, ua.user_id, payload->'$.sync_request' AS sync_request, sc.contents as script_contents, sua.setup_experience_script_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id INNER JOIN script_contents sc ON sua.script_content_id = sc.id WHERE ua.id = ? ` ) _, activityID, err := ds.insertNewHostScriptExecution(ctx, tx, request, isInternal) if err != nil { return nil, ctxerr.Wrap(ctx, err, "inserting new script execution request") } var script fleet.HostScriptResult err = sqlx.GetContext(ctx, tx, &script, getStmt, activityID) if err != nil { return nil, ctxerr.Wrap(ctx, err, "getting the created host script activity to return") } return &script, nil } func (ds *Datastore) insertNewHostScriptExecution(ctx context.Context, tx sqlx.ExtContext, request *fleet.HostScriptRequestPayload, isInternal bool) (string, int64, error) { const ( insUAStmt = ` INSERT INTO upcoming_activities (host_id, priority, user_id, fleet_initiated, activity_type, execution_id, payload) VALUES (?, ?, ?, ?, 'script', ?, JSON_OBJECT( 'sync_request', ?, 'is_internal', ?, 'user', (SELECT JSON_OBJECT('name', name, 'email', email, 'gravatar_url', gravatar_url) FROM users WHERE id = ?) ) )` insSUAStmt = ` INSERT INTO script_upcoming_activities (upcoming_activity_id, script_id, script_content_id, policy_id, setup_experience_script_id) VALUES (?, ?, ?, ?, ?) ` ) execID := uuid.New().String() result, err := tx.ExecContext(ctx, insUAStmt, request.HostID, request.Priority(), request.UserID, request.PolicyID != nil, // fleet-initiated if request is via a policy failure execID, request.SyncRequest, isInternal, request.UserID, ) if err != nil { return "", 0, ctxerr.Wrap(ctx, err, "new script upcoming activity") } activityID, _ := result.LastInsertId() _, err = tx.ExecContext(ctx, insSUAStmt, activityID, request.ScriptID, request.ScriptContentID, request.PolicyID, request.SetupExperienceScriptID, ) if err != nil { return "", 0, ctxerr.Wrap(ctx, err, "new join script upcoming activity") } if _, err := ds.activateNextUpcomingActivity(ctx, tx, request.HostID, ""); err != nil { return "", 0, ctxerr.Wrap(ctx, err, "activate next activity") } return execID, activityID, nil } func truncateScriptResult(output string) string { const maxOutputRuneLen = 10000 if len(output) > utf8.UTFMax*maxOutputRuneLen { // truncate the bytes as we know the output is too long, no point // converting more bytes than needed to runes. output = output[len(output)-(utf8.UTFMax*maxOutputRuneLen):] } if utf8.RuneCountInString(output) > maxOutputRuneLen { outputRunes := []rune(output) output = string(outputRunes[len(outputRunes)-maxOutputRuneLen:]) } return output } func (ds *Datastore) SetHostScriptExecutionResult(ctx context.Context, result *fleet.HostScriptResultPayload) (*fleet.HostScriptResult, string, error, ) { const resultExistsStmt = ` SELECT 1 FROM host_script_results WHERE host_id = ? AND execution_id = ? AND exit_code IS NOT NULL ` const updStmt = ` UPDATE host_script_results SET output = ?, runtime = ?, exit_code = ?, timeout = ? WHERE host_id = ? AND execution_id = ?` const hostMDMActionsStmt = ` SELECT 'uninstall' AS action FROM host_software_installs WHERE execution_id = :execution_id AND host_id = :host_id UNION -- host_mdm_actions query (and thus row in union) must be last to avoid #25144 SELECT CASE WHEN lock_ref = :execution_id THEN 'lock_ref' WHEN unlock_ref = :execution_id THEN 'unlock_ref' WHEN wipe_ref = :execution_id THEN 'wipe_ref' ELSE '' END AS action FROM host_mdm_actions WHERE host_id = :host_id ` output := truncateScriptResult(result.Output) var hsr *fleet.HostScriptResult var action string err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var resultExists bool err := sqlx.GetContext(ctx, tx, &resultExists, resultExistsStmt, result.HostID, result.ExecutionID) if err != nil && !errors.Is(err, sql.ErrNoRows) { return ctxerr.Wrap(ctx, err, "check if host script result exists") } if resultExists { level.Debug(ds.logger).Log("msg", "duplicate script execution result sent, will be ignored (original result is preserved)", "host_id", result.HostID, "execution_id", result.ExecutionID, ) // still do the activate next activity to ensure progress as there was // an unexpected flow if we get here. if _, err := ds.activateNextUpcomingActivity(ctx, tx, result.HostID, result.ExecutionID); err != nil { return ctxerr.Wrap(ctx, err, "activate next activity") } // succeed but leave hsr nil return nil } res, err := tx.ExecContext(ctx, updStmt, output, result.Runtime, // Windows error codes are signed 32-bit integers, but are // returned as unsigned integers by the windows API. The // software that receives them is responsible for casting // it to a 32-bit signed integer. // See /orbit/pkg/scripts/exec_windows.go int32(result.ExitCode), //nolint:gosec // dismiss G115 result.Timeout, result.HostID, result.ExecutionID, ) if err != nil { return ctxerr.Wrap(ctx, err, "update host script result") } if n, _ := res.RowsAffected(); n > 0 { // it did update, so return the updated result hsr, err = ds.getHostScriptExecutionResultDB(ctx, tx, result.ExecutionID, scriptExecutionSearchOpts{IncludeCanceled: true}) if err != nil { return ctxerr.Wrap(ctx, err, "load updated host script result") } // look up if that script was a lock/unlock/wipe/uninstall script for that host, // and if so update the host_mdm_actions table accordingly. namedArgs := map[string]any{ "host_id": result.HostID, "execution_id": result.ExecutionID, } stmt, args, err := sqlx.Named(hostMDMActionsStmt, namedArgs) if err != nil { return ctxerr.Wrap(ctx, err, "build named query for host mdm actions") } err = sqlx.GetContext(ctx, tx, &action, stmt, args...) if err != nil && !errors.Is(err, sql.ErrNoRows) { // ignore ErrNoRows, refCol will be empty return ctxerr.Wrap(ctx, err, "lookup host script corresponding mdm action") } switch action { case "": // do nothing case "uninstall": err = ds.updateUninstallStatusFromResult(ctx, tx, result.HostID, result.ExecutionID, result.ExitCode) if err != nil { return ctxerr.Wrap(ctx, err, "update host uninstall action based on script result") } default: // lock/unlock/wipe err = updateHostLockWipeStatusFromResult(ctx, tx, result.HostID, action, result.ExitCode == 0) if err != nil { return ctxerr.Wrap(ctx, err, "update host mdm action based on script result") } } } if _, err := ds.activateNextUpcomingActivity(ctx, tx, result.HostID, result.ExecutionID); err != nil { return ctxerr.Wrap(ctx, err, "activate next activity") } return nil }) if err != nil { return nil, "", err } return hsr, action, nil } func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID uint, onlyShowInternal bool) ([]*fleet.HostScriptResult, error) { return ds.listUpcomingHostScriptExecutions(ctx, hostID, onlyShowInternal, false) } func (ds *Datastore) ListReadyToExecuteScriptsForHost(ctx context.Context, hostID uint, onlyShowInternal bool) ([]*fleet.HostScriptResult, error) { return ds.listUpcomingHostScriptExecutions(ctx, hostID, onlyShowInternal, true) } func (ds *Datastore) listUpcomingHostScriptExecutions(ctx context.Context, hostID uint, onlyShowInternal, onlyReadyToExecute bool) ([]*fleet.HostScriptResult, error) { extraWhere := "" if onlyShowInternal { // software_uninstalls are implicitly internal extraWhere = " AND COALESCE(ua.payload->'$.is_internal', 1) = 1" } if onlyReadyToExecute { extraWhere += " AND ua.activated_at IS NOT NULL" } // this selects software uninstalls too as they run as scripts listStmt := fmt.Sprintf(` SELECT id, host_id, execution_id, script_id, created_at FROM ( SELECT ua.id, ua.host_id, ua.execution_id, sua.script_id, ua.priority, ua.created_at, IF(ua.activated_at IS NULL, 0, 1) AS topmost FROM upcoming_activities ua -- left join because software_uninstall has no script join LEFT JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.host_id = ? AND ua.activity_type IN ('script', 'software_uninstall') %s ORDER BY topmost DESC, priority DESC, created_at ASC) t`, extraWhere) var results []*fleet.HostScriptResult if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, listStmt, hostID); err != nil { return nil, ctxerr.Wrap(ctx, err, "list pending host script executions") } return results, nil } func (ds *Datastore) IsExecutionPendingForHost(ctx context.Context, hostID uint, scriptID uint) (bool, error) { const getStmt = ` SELECT 1 FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.host_id = ? AND ua.activity_type = 'script' AND sua.script_id = ? ` var results []*uint if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, getStmt, hostID, scriptID); err != nil { return false, ctxerr.Wrap(ctx, err, "is execution pending for host") } return len(results) > 0, nil } type scriptExecutionSearchOpts struct { IncludeCanceled bool UninstallHostID uint } func (ds *Datastore) GetHostScriptExecutionResult(ctx context.Context, execID string) (*fleet.HostScriptResult, error) { return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID, scriptExecutionSearchOpts{}) } func (ds *Datastore) GetSelfServiceUninstallScriptExecutionResult(ctx context.Context, execID string, hostID uint) (*fleet.HostScriptResult, error) { return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID, scriptExecutionSearchOpts{UninstallHostID: hostID}) } func (ds *Datastore) getHostScriptExecutionResultDB(ctx context.Context, q sqlx.QueryerContext, execID string, opts scriptExecutionSearchOpts) (*fleet.HostScriptResult, error) { var activeParams []any canceledCondition := "" if !opts.IncludeCanceled { canceledCondition = " AND hsr.canceled = 0" } uninstallCondition := "" if opts.UninstallHostID > 0 { uninstallCondition = `JOIN host_software_installs hsi ON hsi.execution_id = hsr.execution_id AND hsi.uninstall = TRUE AND hsr.host_id = ?` activeParams = append(activeParams, opts.UninstallHostID) } activeParams = append(activeParams, execID) getActiveStmt := fmt.Sprintf(` SELECT hsr.id, hsr.host_id, hsr.execution_id, sc.contents as script_contents, hsr.script_id, hsr.policy_id, hsr.output, hsr.runtime, hsr.exit_code, hsr.timeout, hsr.created_at, hsr.user_id, hsr.sync_request, hsr.host_deleted_at, hsr.setup_experience_script_id, hsr.canceled, bahr.batch_execution_id FROM host_script_results hsr LEFT JOIN batch_activity_host_results bahr ON hsr.execution_id = bahr.host_execution_id JOIN script_contents sc %s WHERE hsr.execution_id = ? AND hsr.script_content_id = sc.id %s `, uninstallCondition, canceledCondition) // We don't include upcoming uninstall script executions in results (different activity type, and they're blank anyway) const getUpcomingStmt = ` SELECT 0 as id, ua.host_id, ua.execution_id, sc.contents as script_contents, sua.script_id, sua.policy_id, '' as output, 0 as runtime, NULL as exit_code, NULL as timeout, ua.created_at, ua.user_id, COALESCE(ua.payload->'$.sync_request', 0) as sync_request, NULL as host_deleted_at, sua.setup_experience_script_id, 0 as canceled FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id INNER JOIN script_contents sc ON sua.script_content_id = sc.id WHERE ua.execution_id = ? AND ua.activity_type = 'script' ` var result fleet.HostScriptResult if err := sqlx.GetContext(ctx, q, &result, getActiveStmt, activeParams...); err != nil { if errors.Is(err, sql.ErrNoRows) { // try with upcoming activities err = sqlx.GetContext(ctx, q, &result, getUpcomingStmt, execID) if errors.Is(err, sql.ErrNoRows) { return nil, ctxerr.Wrap(ctx, notFound("HostScriptResult").WithName(execID)) } } if err != nil { return nil, ctxerr.Wrap(ctx, err, "get host script result") } } return &result, nil } func (ds *Datastore) NewScript(ctx context.Context, script *fleet.Script) (*fleet.Script, error) { var res sql.Result err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var err error // first insert script contents scRes, err := insertScriptContents(ctx, tx, script.ScriptContents) if err != nil { return err } id, _ := scRes.LastInsertId() // then create the script entity res, err = insertScript(ctx, tx, script, uint(id)) //nolint:gosec // dismiss G115 return err }) if err != nil { return nil, err } id, _ := res.LastInsertId() return ds.getScriptDB(ctx, ds.writer(ctx), uint(id)) //nolint:gosec // dismiss G115 } func (ds *Datastore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) { err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { // Get the current script_content_id var oldContentID int64 getCurrentStmt := `SELECT script_content_id FROM scripts WHERE id = ?` err := sqlx.GetContext(ctx, tx, &oldContentID, getCurrentStmt, scriptID) if err != nil { return ctxerr.Wrap(ctx, err, "getting current script content id") } // Insert or get existing content (insertScriptContents handles deduplication) scRes, err := insertScriptContents(ctx, tx, scriptContents) if err != nil { return ctxerr.Wrap(ctx, err, "inserting/getting script contents") } newContentID, _ := scRes.LastInsertId() // Update the script to point to the new content if newContentID != oldContentID { updateStmt := ` UPDATE scripts SET script_content_id = ? WHERE id = ? ` _, err = tx.ExecContext(ctx, updateStmt, newContentID, scriptID) if err != nil { return ctxerr.Wrap(ctx, err, "updating script content reference") } // Try to clean up the old content if no longer used // Don't fail the transaction if cleanup fails; just log it if err := ds.cleanupScriptContent(ctx, tx, uint(oldContentID)); err != nil { //nolint:gosec level.Error(ds.logger).Log("msg", "failed to cleanup orphaned script content", "script_id", scriptID, "old_content_id", oldContentID, "err", err) ctxerr.Handle(ctx, err) } } else { // Just update the timestamp _, err = tx.ExecContext(ctx, "UPDATE scripts SET updated_at = NOW() WHERE id = ?", scriptID) if err != nil { return ctxerr.Wrap(ctx, err, "updating script updated_at time") } } // Cancel pending executions if err := ds.cancelUpcomingScriptActivities(ctx, tx, scriptID); err != nil { return ctxerr.Wrap(ctx, err, "canceling upcoming script executions") } return nil }) if err != nil { return nil, ctxerr.Wrap(ctx, err, "updating script contents") } return ds.Script(ctx, scriptID) } func (ds *Datastore) cancelUpcomingScriptActivities(ctx context.Context, db sqlx.ExtContext, scriptID uint) error { const stmt = ` SELECT ua.execution_id, ua.host_id FROM script_upcoming_activities sua INNER JOIN upcoming_activities ua ON ua.id = sua.upcoming_activity_id WHERE sua.script_id = ? ` var upcomingExecutions []struct { ExecutionID string `db:"execution_id"` HostID uint `db:"host_id"` } if err := sqlx.SelectContext(ctx, db, &upcomingExecutions, stmt, scriptID); err != nil { return ctxerr.Wrap(ctx, err, "selecting upcoming script executions") } for _, upcomingExecution := range upcomingExecutions { if _, err := ds.cancelHostUpcomingActivity(ctx, db, upcomingExecution.HostID, upcomingExecution.ExecutionID); err != nil { return ctxerr.Wrap(ctx, err, "canceling upcoming activity") } } return nil } func insertScript(ctx context.Context, tx sqlx.ExtContext, script *fleet.Script, scriptContentsID uint) (sql.Result, error) { const insertStmt = ` INSERT INTO scripts ( team_id, global_or_team_id, name, script_content_id ) VALUES (?, ?, ?, ?) ` var globalOrTeamID uint if script.TeamID != nil { globalOrTeamID = *script.TeamID } res, err := tx.ExecContext(ctx, insertStmt, script.TeamID, globalOrTeamID, script.Name, scriptContentsID) if err != nil { if IsDuplicate(err) { // name already exists for this team/global err = alreadyExists("Script", script.Name) } else if isChildForeignKeyError(err) { // team does not exist err = foreignKey("scripts", fmt.Sprintf("team_id=%v", script.TeamID)) } return nil, ctxerr.Wrap(ctx, err, "insert script") } return res, nil } func insertScriptContents(ctx context.Context, tx sqlx.ExtContext, contents string) (sql.Result, error) { const insertStmt = ` INSERT INTO script_contents ( md5_checksum, contents ) VALUES (UNHEX(?),?) ON DUPLICATE KEY UPDATE id=LAST_INSERT_ID(id) ` md5Checksum := md5ChecksumScriptContent(contents) res, err := tx.ExecContext(ctx, insertStmt, md5Checksum, contents) if err != nil { return nil, ctxerr.Wrap(ctx, err, "insert script contents") } return res, nil } func md5ChecksumScriptContent(s string) string { return md5ChecksumBytes([]byte(s)) } func md5ChecksumBytes(b []byte) string { rawChecksum := md5.Sum(b) //nolint:gosec return strings.ToUpper(hex.EncodeToString(rawChecksum[:])) } func (ds *Datastore) cleanupScriptContent(ctx context.Context, tx sqlx.ExtContext, contentID uint) error { // Check if this content is still being used anywhere var usageCount int stmt := ` SELECT COUNT(*) FROM ( SELECT 1 FROM scripts WHERE script_content_id = ? UNION ALL SELECT 1 FROM setup_experience_scripts WHERE script_content_id = ? UNION ALL SELECT 1 FROM software_installers WHERE install_script_content_id = ? OR uninstall_script_content_id = ? OR post_install_script_content_id = ? UNION ALL SELECT 1 FROM script_upcoming_activities WHERE script_content_id = ? UNION ALL SELECT 1 FROM host_script_results WHERE script_content_id = ? ) t ` err := sqlx.GetContext(ctx, tx, &usageCount, stmt, contentID, contentID, contentID, contentID, contentID, contentID, contentID) if err != nil { return ctxerr.Wrap(ctx, err, "checking script content usage for cleanup") } if usageCount == 0 { // Not being used, safe to delete deleteStmt := `DELETE FROM script_contents WHERE id = ?` _, err = tx.ExecContext(ctx, deleteStmt, contentID) if err != nil { return ctxerr.Wrap(ctx, err, "deleting unused script content") } } return nil } func (ds *Datastore) Script(ctx context.Context, id uint) (*fleet.Script, error) { return ds.getScriptDB(ctx, ds.reader(ctx), id) } func (ds *Datastore) getScriptDB(ctx context.Context, q sqlx.QueryerContext, id uint) (*fleet.Script, error) { const getStmt = ` SELECT id, team_id, name, created_at, updated_at, script_content_id FROM scripts WHERE id = ? ` var script fleet.Script if err := sqlx.GetContext(ctx, q, &script, getStmt, id); err != nil { if err == sql.ErrNoRows { return nil, notFound("Script").WithID(id) } return nil, ctxerr.Wrap(ctx, err, "get script") } return &script, nil } func (ds *Datastore) GetScriptContents(ctx context.Context, id uint) ([]byte, error) { const getStmt = ` SELECT sc.contents FROM script_contents sc JOIN scripts s ON s.script_content_id = sc.id WHERE s.id = ? ` var contents []byte if err := sqlx.GetContext(ctx, ds.reader(ctx), &contents, getStmt, id); err != nil { if err == sql.ErrNoRows { return nil, notFound("Script").WithID(id) } return nil, ctxerr.Wrap(ctx, err, "get script contents") } return contents, nil } func (ds *Datastore) GetAnyScriptContents(ctx context.Context, id uint) ([]byte, error) { const getStmt = ` SELECT sc.contents FROM script_contents sc WHERE sc.id = ? ` var contents []byte if err := sqlx.GetContext(ctx, ds.reader(ctx), &contents, getStmt, id); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, notFound("Script").WithID(id) } return nil, ctxerr.Wrap(ctx, err, "get any script contents") } return contents, nil } var errDeleteScriptWithAssociatedPolicy = &fleet.ConflictError{Message: "Couldn't delete. Policy automation uses this script. Please remove this script from associated policy automations and try again."} func (ds *Datastore) DeleteScript(ctx context.Context, id uint) error { var activateAffectedHosts []uint err := ds.withTx(ctx, func(tx sqlx.ExtContext) error { _, err := tx.ExecContext(ctx, `DELETE FROM host_script_results WHERE script_id = ? AND exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)`, id, int(constants.MaxServerWaitTime.Seconds()), ) if err != nil { return ctxerr.Wrapf(ctx, err, "cancel pending script executions") } // load hosts that will have their upcoming_activities deleted, if that // activity is "activated", as that means we will have to call // activateNextUpcomingActivity for those hosts. loadAffectedHostsStmt := ` SELECT DISTINCT host_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE sua.script_id = ? AND ua.activity_type = 'script' AND ua.activated_at IS NOT NULL AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)` var affectedHosts []uint if err := sqlx.SelectContext(ctx, tx, &affectedHosts, loadAffectedHostsStmt, id, int(constants.MaxServerWaitTime.Seconds())); err != nil { return ctxerr.Wrapf(ctx, err, "load affected hosts") } activateAffectedHosts = affectedHosts _, err = tx.ExecContext(ctx, `DELETE FROM upcoming_activities USING upcoming_activities INNER JOIN script_upcoming_activities sua ON upcoming_activities.id = sua.upcoming_activity_id WHERE sua.script_id = ? AND upcoming_activities.activity_type = 'script' AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND) `, id, int(constants.MaxServerWaitTime.Seconds()), ) if err != nil { return ctxerr.Wrapf(ctx, err, "cancel upcoming pending script executions") } _, err = tx.ExecContext(ctx, `DELETE FROM scripts WHERE id = ?`, id) if err != nil { if isMySQLForeignKey(err) { // Check if the script is referenced by a policy automation. var count int if err := sqlx.GetContext(ctx, tx, &count, `SELECT COUNT(*) FROM policies WHERE script_id = ?`, id); err != nil { return ctxerr.Wrapf(ctx, err, "getting reference from policies") } if count > 0 { return ctxerr.Wrap(ctx, errDeleteScriptWithAssociatedPolicy, "delete script") } } return ctxerr.Wrap(ctx, err, "delete script") } return nil }) if err != nil { return err } // we call this outside of the transaction to avoid a // long-running/deadlock-prone transaction, as many hosts could be affected. return ds.activateNextUpcomingActivityForBatchOfHosts(ctx, activateAffectedHosts) } // deletePendingHostScriptExecutionsForPolicy should be called when a policy is deleted to remove any pending script executions func (ds *Datastore) deletePendingHostScriptExecutionsForPolicy(ctx context.Context, teamID *uint, policyID uint) error { var globalOrTeamID uint if teamID != nil { globalOrTeamID = *teamID } deletePendingFunc := func(stmt string, args ...any) error { _, err := ds.writer(ctx).ExecContext(ctx, stmt, args...) return ctxerr.Wrap(ctx, err, "delete pending host script executions for policy") } deleteHSRStmt := ` DELETE FROM host_script_results WHERE policy_id = ? AND script_id IN ( SELECT id FROM scripts WHERE scripts.global_or_team_id = ? ) AND exit_code IS NULL ` if err := deletePendingFunc(deleteHSRStmt, policyID, globalOrTeamID); err != nil { return err } loadAffectedHostsStmt := ` SELECT DISTINCT host_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.activity_type = 'script' AND ua.activated_at IS NOT NULL AND sua.policy_id = ? AND sua.script_id IN ( SELECT id FROM scripts WHERE scripts.global_or_team_id = ? )` var affectedHosts []uint if err := sqlx.SelectContext(ctx, ds.reader(ctx), &affectedHosts, loadAffectedHostsStmt, policyID, globalOrTeamID); err != nil { return err } deleteUAStmt := ` DELETE FROM upcoming_activities USING upcoming_activities INNER JOIN script_upcoming_activities sua ON upcoming_activities.id = sua.upcoming_activity_id WHERE upcoming_activities.activity_type = 'script' AND sua.policy_id = ? AND sua.script_id IN ( SELECT id FROM scripts WHERE scripts.global_or_team_id = ? ) ` if err := deletePendingFunc(deleteUAStmt, policyID, globalOrTeamID); err != nil { return err } return ds.activateNextUpcomingActivityForBatchOfHosts(ctx, affectedHosts) } func (ds *Datastore) ListScripts(ctx context.Context, teamID *uint, opt fleet.ListOptions) ([]*fleet.Script, *fleet.PaginationMetadata, error) { var scripts []*fleet.Script const selectStmt = ` SELECT s.id, s.team_id, s.name, s.created_at, s.updated_at FROM scripts s WHERE s.global_or_team_id = ? ` var globalOrTeamID uint if teamID != nil { globalOrTeamID = *teamID } args := []any{globalOrTeamID} stmt, args := appendListOptionsWithCursorToSQL(selectStmt, args, &opt) if err := sqlx.SelectContext(ctx, ds.reader(ctx), &scripts, stmt, args...); err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "select scripts") } var metaData *fleet.PaginationMetadata if opt.IncludeMetadata { metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0} if len(scripts) > int(opt.PerPage) { //nolint:gosec // dismiss G115 metaData.HasNextResults = true scripts = scripts[:len(scripts)-1] } } return scripts, metaData, nil } func (ds *Datastore) GetScriptIDByName(ctx context.Context, name string, teamID *uint) (uint, error) { const selectStmt = ` SELECT id FROM scripts WHERE global_or_team_id = ? AND name = ? ` var globalOrTeamID uint if teamID != nil { globalOrTeamID = *teamID } var id uint if err := sqlx.GetContext(ctx, ds.reader(ctx), &id, selectStmt, globalOrTeamID, name); err != nil { if err == sql.ErrNoRows { return 0, notFound("Script").WithName(name) } return 0, ctxerr.Wrap(ctx, err, "get script by name") } return id, nil } func (ds *Datastore) GetHostScriptDetails(ctx context.Context, hostID uint, teamID *uint, opt fleet.ListOptions, hostPlatform string) ([]*fleet.HostScriptDetail, *fleet.PaginationMetadata, error) { var globalOrTeamID uint if teamID != nil { globalOrTeamID = *teamID } var extension string switch { case hostPlatform == "windows": // filter by .ps1 extension extension = `%.ps1` case fleet.IsUnixLike(hostPlatform): // filter by .sh extension extension = `%.sh` default: // no extension filter } type row struct { ScriptID uint `db:"script_id"` Name string `db:"name"` HSRID *uint `db:"hsr_id"` ExecutionID *string `db:"execution_id"` ExecutedAt *time.Time `db:"executed_at"` ExitCode *int64 `db:"exit_code"` } sql := ` WITH all_latest_activities AS ( -- Use window function to efficiently find the latest execution per script -- This is O(n) (a self-join approach would be O(n²)) SELECT * FROM ( SELECT id, host_id, script_id, execution_id, created_at, exit_code, 'completed' as source, ROW_NUMBER() OVER ( PARTITION BY script_id ORDER BY created_at DESC, id DESC ) AS row_num FROM host_script_results WHERE host_id = ? AND canceled = 0 ) completed_ranked WHERE row_num = 1 UNION ALL -- latest from upcoming_activities SELECT * FROM ( SELECT NULL as id, ua.host_id, sua.script_id, ua.execution_id, ua.created_at, NULL as exit_code, 'upcoming' as source, ROW_NUMBER() OVER ( PARTITION BY sua.script_id ORDER BY ua.created_at DESC, ua.id DESC ) AS row_num FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.host_id = ? AND ua.activity_type = 'script' ) upcoming_ranked WHERE row_num = 1 ) SELECT s.id AS script_id, s.name, latest.id AS hsr_id, latest.created_at AS executed_at, latest.execution_id, latest.exit_code FROM scripts s LEFT JOIN ( -- Pick the most recent between completed and upcoming for each script SELECT * FROM ( SELECT *, ROW_NUMBER() OVER ( PARTITION BY script_id ORDER BY CASE WHEN source = 'upcoming' THEN 1 ELSE 2 END, -- Prefer upcoming over completed created_at DESC, id DESC ) AS final_rn FROM all_latest_activities ) final_ranked WHERE final_rn = 1 ) latest ON s.id = latest.script_id WHERE (latest.host_id IS NULL OR latest.host_id = ?) AND s.global_or_team_id = ? ` args := []any{hostID, hostID, hostID, globalOrTeamID} if len(extension) > 0 { args = append(args, extension) sql += ` AND s.name LIKE ? ` } stmt, args := appendListOptionsWithCursorToSQL(sql, args, &opt) var rows []*row if err := sqlx.SelectContext(ctx, ds.reader(ctx), &rows, stmt, args...); err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get host script details") } var metaData *fleet.PaginationMetadata if opt.IncludeMetadata { metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0} if len(rows) > int(opt.PerPage) { //nolint:gosec // dismiss G115 metaData.HasNextResults = true rows = rows[:len(rows)-1] } } results := make([]*fleet.HostScriptDetail, 0, len(rows)) for _, r := range rows { results = append(results, fleet.NewHostScriptDetail(hostID, r.ScriptID, r.Name, r.ExecutionID, r.ExecutedAt, r.ExitCode, r.HSRID)) } return results, metaData, nil } func (ds *Datastore) BatchSetScripts(ctx context.Context, tmID *uint, scripts []*fleet.Script) ([]fleet.ScriptResponse, error) { const loadExistingScripts = ` SELECT name FROM scripts WHERE global_or_team_id = ? AND name IN (?) ` const deleteAllScriptsInTeam = ` DELETE FROM scripts WHERE global_or_team_id = ? ` const unsetAllScriptsFromPolicies = `UPDATE policies SET script_id = NULL WHERE team_id = ?` const clearAllPendingExecutionsHSR = `DELETE FROM host_script_results WHERE exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND) AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)` const loadAffectedHostsAllPendingExecutionsUA = ` SELECT DISTINCT host_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.activity_type = 'script' AND ua.activated_at IS NOT NULL AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)` const clearAllPendingExecutionsUA = `DELETE FROM upcoming_activities USING upcoming_activities INNER JOIN script_upcoming_activities sua ON upcoming_activities.id = sua.upcoming_activity_id WHERE upcoming_activities.activity_type = 'script' AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)` const unsetScriptsNotInListFromPolicies = ` UPDATE policies SET script_id = NULL WHERE script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?)) ` const deleteScriptsNotInList = ` DELETE FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?) ` const clearPendingExecutionsNotInListHSR = `DELETE FROM host_script_results WHERE exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND) AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))` const loadAffectedHostsPendingExecutionsNotInListUA = ` SELECT DISTINCT host_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.activity_type = 'script' AND ua.activated_at IS NOT NULL AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))` const clearPendingExecutionsNotInListUA = `DELETE FROM upcoming_activities USING upcoming_activities INNER JOIN script_upcoming_activities sua ON upcoming_activities.id = sua.upcoming_activity_id WHERE upcoming_activities.activity_type = 'script' AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))` const insertNewOrEditedScript = ` INSERT INTO scripts ( team_id, global_or_team_id, name, script_content_id ) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE script_content_id = VALUES(script_content_id), id=LAST_INSERT_ID(id) ` const clearPendingExecutionsWithObsoleteScriptHSR = `DELETE FROM host_script_results WHERE exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND) AND script_id = ? AND script_content_id != ?` const loadAffectedHostsPendingExecutionsWithObsoleteScriptUA = ` SELECT DISTINCT host_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.activity_type = 'script' AND ua.activated_at IS NOT NULL AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id = ? AND sua.script_content_id != ?` const clearPendingExecutionsWithObsoleteScriptUA = `DELETE FROM upcoming_activities USING upcoming_activities INNER JOIN script_upcoming_activities sua ON upcoming_activities.id = sua.upcoming_activity_id WHERE upcoming_activities.activity_type = 'script' AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id = ? AND sua.script_content_id != ?` const loadInsertedScripts = `SELECT id, team_id, name FROM scripts WHERE global_or_team_id = ?` // use a team id of 0 if no-team var globalOrTeamID uint if tmID != nil { globalOrTeamID = *tmID } // build a list of names for the incoming scripts, will keep the // existing ones if there's a match and no change incomingNames := make([]string, len(scripts)) // at the same time, index the incoming scripts keyed by name for ease // of processing incomingScripts := make(map[string]*fleet.Script, len(scripts)) for i, p := range scripts { incomingNames[i] = p.Name incomingScripts[p.Name] = p } var insertedScripts []fleet.ScriptResponse var activateAffectedHosts []uint if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var existingScripts []*fleet.Script if len(incomingNames) > 0 { // load existing scripts that match the incoming scripts by names stmt, args, err := sqlx.In(loadExistingScripts, globalOrTeamID, incomingNames) if err != nil { return ctxerr.Wrap(ctx, err, "build query to load existing scripts") } if err := sqlx.SelectContext(ctx, tx, &existingScripts, stmt, args...); err != nil { return ctxerr.Wrap(ctx, err, "load existing scripts") } } // figure out if we need to delete any scripts keepNames := make([]string, 0, len(incomingNames)) for _, p := range existingScripts { if newS := incomingScripts[p.Name]; newS != nil { keepNames = append(keepNames, p.Name) } } var ( scriptsStmt string scriptsArgs []any policiesStmt string policiesArgs []any executionsStmt string executionsArgs []any extraExecStmt string extraExecArgs []any err error affectedHostIDs []uint ) if len(keepNames) > 0 { // delete the obsolete scripts scriptsStmt, scriptsArgs, err = sqlx.In(deleteScriptsNotInList, globalOrTeamID, keepNames) if err != nil { return ctxerr.Wrap(ctx, err, "build statement to delete obsolete scripts") } policiesStmt, policiesArgs, err = sqlx.In(unsetScriptsNotInListFromPolicies, globalOrTeamID, keepNames) if err != nil { return ctxerr.Wrap(ctx, err, "build statement to unset obsolete scripts from policies") } executionsStmt, executionsArgs, err = sqlx.In(clearPendingExecutionsNotInListHSR, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames) if err != nil { return ctxerr.Wrap(ctx, err, "build statement to clear pending script executions from obsolete scripts") } loadAffectedStmt, args, err := sqlx.In(loadAffectedHostsPendingExecutionsNotInListUA, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames) if err != nil { return ctxerr.Wrap(ctx, err, "build query to load affected hosts for upcoming script executions") } if err := sqlx.SelectContext(ctx, tx, &affectedHostIDs, loadAffectedStmt, args...); err != nil { return ctxerr.Wrap(ctx, err, "load affected hosts for upcoming script executions") } extraExecStmt, extraExecArgs, err = sqlx.In(clearPendingExecutionsNotInListUA, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames) if err != nil { return ctxerr.Wrap(ctx, err, "build statement to clear upcoming pending script executions from obsolete scripts") } } else { scriptsStmt = deleteAllScriptsInTeam scriptsArgs = []any{globalOrTeamID} policiesStmt = unsetAllScriptsFromPolicies policiesArgs = []any{globalOrTeamID} executionsStmt = clearAllPendingExecutionsHSR executionsArgs = []any{int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID} if err := sqlx.SelectContext(ctx, tx, &affectedHostIDs, loadAffectedHostsAllPendingExecutionsUA, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID); err != nil { return ctxerr.Wrap(ctx, err, "load affected hosts for upcoming script executions") } extraExecStmt = clearAllPendingExecutionsUA extraExecArgs = []any{int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID} } if _, err := tx.ExecContext(ctx, policiesStmt, policiesArgs...); err != nil { return ctxerr.Wrap(ctx, err, "unset obsolete scripts from policies") } if _, err := tx.ExecContext(ctx, executionsStmt, executionsArgs...); err != nil { return ctxerr.Wrap(ctx, err, "clear obsolete script pending executions") } if _, err := tx.ExecContext(ctx, extraExecStmt, extraExecArgs...); err != nil { return ctxerr.Wrap(ctx, err, "clear obsolete upcoming script pending executions") } if _, err := tx.ExecContext(ctx, scriptsStmt, scriptsArgs...); err != nil { return ctxerr.Wrap(ctx, err, "delete obsolete scripts") } activateAffectedHosts = affectedHostIDs // insert the new scripts and the ones that have changed for _, s := range incomingScripts { scRes, err := insertScriptContents(ctx, tx, s.ScriptContents) if err != nil { return ctxerr.Wrapf(ctx, err, "inserting script contents for script with name %q", s.Name) } contentID, _ := scRes.LastInsertId() insertRes, err := tx.ExecContext(ctx, insertNewOrEditedScript, tmID, globalOrTeamID, s.Name, uint(contentID)) //nolint:gosec // dismiss G115 if err != nil { return ctxerr.Wrapf(ctx, err, "insert new/edited script with name %q", s.Name) } scriptID, _ := insertRes.LastInsertId() if _, err := tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScriptHSR, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil { return ctxerr.Wrapf(ctx, err, "clear obsolete pending script executions with name %q", s.Name) } var affectedHosts []uint if err := sqlx.SelectContext(ctx, tx, &affectedHosts, loadAffectedHostsPendingExecutionsWithObsoleteScriptUA, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil { return ctxerr.Wrapf(ctx, err, "load affected hosts for upcoming script executions with name %q", s.Name) } activateAffectedHosts = append(activateAffectedHosts, affectedHosts...) if _, err = tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScriptUA, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil { return ctxerr.Wrapf(ctx, err, "clear obsolete upcoming pending script executions with name %q", s.Name) } } if err := sqlx.SelectContext(ctx, tx, &insertedScripts, loadInsertedScripts, globalOrTeamID); err != nil { return ctxerr.Wrap(ctx, err, "load inserted scripts") } return nil }); err != nil { return nil, err } if err := ds.activateNextUpcomingActivityForBatchOfHosts(ctx, activateAffectedHosts); err != nil { return nil, ctxerr.Wrap(ctx, err, "activate next upcoming activity for batch of hosts") } return insertedScripts, nil } func (ds *Datastore) GetHostLockWipeStatus(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) { const stmt = ` SELECT lock_ref, wipe_ref, unlock_ref, unlock_pin, fleet_platform FROM host_mdm_actions WHERE host_id = ? ` var mdmActions struct { LockRef *string `db:"lock_ref"` WipeRef *string `db:"wipe_ref"` UnlockRef *string `db:"unlock_ref"` UnlockPIN *string `db:"unlock_pin"` FleetPlatform string `db:"fleet_platform"` } fleetPlatform := host.FleetPlatform() status := &fleet.HostLockWipeStatus{ HostFleetPlatform: fleetPlatform, } if err := sqlx.GetContext(ctx, ds.reader(ctx), &mdmActions, stmt, host.ID); err != nil { if err == sql.ErrNoRows { // do not return a Not Found error, return the zero-value status, which // will report the correct states. return status, nil } return nil, ctxerr.Wrap(ctx, err, "get host lock/wipe status") } // if we have a fleet platform stored in host_mdm_actions, use it instead of // the host.FleetPlatform() because the platform can be overwritten with an // unknown OS name when a Wipe gets executed. if mdmActions.FleetPlatform != "" { fleetPlatform = mdmActions.FleetPlatform status.HostFleetPlatform = fleetPlatform } switch fleetPlatform { case "darwin", "ios", "ipados": if mdmActions.UnlockPIN != nil { status.UnlockPIN = *mdmActions.UnlockPIN } if mdmActions.UnlockRef != nil { var err error status.UnlockRequestedAt, err = time.Parse(time.DateTime, *mdmActions.UnlockRef) if err != nil { // if the format is unexpected but there's something in UnlockRef, just // replace it with the current timestamp, it should still indicate that // an unlock was requested (e.g. in case someone plays with the data // directly in the DB and messes up the format). status.UnlockRequestedAt = time.Now().UTC() } } if mdmActions.LockRef != nil { // the lock reference is an MDM command cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.LockRef, host.UUID) if err != nil { return nil, ctxerr.Wrap(ctx, err, "get lock reference") } status.LockMDMCommand = cmd status.LockMDMCommandResult = cmdRes } if mdmActions.WipeRef != nil { // the wipe reference is an MDM command cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.WipeRef, host.UUID) if err != nil { return nil, ctxerr.Wrap(ctx, err, "get wipe reference") } status.WipeMDMCommand = cmd status.WipeMDMCommandResult = cmdRes } case "windows", "linux": // lock and unlock references are scripts if mdmActions.LockRef != nil { hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.LockRef, scriptExecutionSearchOpts{IncludeCanceled: true}) if err != nil { return nil, ctxerr.Wrap(ctx, err, "get lock reference script result") } status.LockScript = hsr } if mdmActions.UnlockRef != nil { hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.UnlockRef, scriptExecutionSearchOpts{IncludeCanceled: true}) if err != nil { return nil, ctxerr.Wrap(ctx, err, "get unlock reference script result") } status.UnlockScript = hsr } // wipe is an MDM command on Windows, a script on Linux if mdmActions.WipeRef != nil { if fleetPlatform == "windows" { cmd, cmdRes, err := ds.getHostMDMWindowsCommand(ctx, *mdmActions.WipeRef, host.UUID) if err != nil { return nil, ctxerr.Wrap(ctx, err, "get wipe reference") } status.WipeMDMCommand = cmd status.WipeMDMCommandResult = cmdRes } else { hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.WipeRef, scriptExecutionSearchOpts{IncludeCanceled: true}) if err != nil { return nil, ctxerr.Wrap(ctx, err, "get wipe reference script result") } status.WipeScript = hsr } } } return status, nil } func (ds *Datastore) getHostMDMWindowsCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) { cmd, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command") } // get the MDM command result, which may be not found (indicating the command // is pending). Note that it doesn't return ErrNoRows if not found, it // returns success and an empty cmdRes slice. cmdResults, err := ds.GetMDMWindowsCommandResults(ctx, cmdUUID) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command result") } // each item in the slice returned by GetMDMWindowsCommandResults is // potentially a result for a different host, we need to find the one for // that specific host. var cmdRes *fleet.MDMCommandResult for _, r := range cmdResults { if r.HostUUID != hostUUID { continue } // all statuses for Windows indicate end of processing of the command // (there is no equivalent of "NotNow" or "Idle" as for Apple). cmdRes = r break } return cmd, cmdRes, nil } func (ds *Datastore) getHostMDMAppleCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) { cmd, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command") } // get the MDM command result, which may be not found (indicating the command // is pending). Note that it doesn't return ErrNoRows if not found, it // returns success and an empty cmdRes slice. cmdResults, err := ds.GetMDMAppleCommandResults(ctx, cmdUUID) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command result") } // each item in the slice returned by GetMDMAppleCommandResults is // potentially a result for a different host, we need to find the one for // that specific host. var cmdRes *fleet.MDMCommandResult for _, r := range cmdResults { if r.HostUUID != hostUUID { continue } if r.Status == fleet.MDMAppleStatusAcknowledged || r.Status == fleet.MDMAppleStatusError || r.Status == fleet.MDMAppleStatusCommandFormatError { cmdRes = r break } } return cmd, cmdRes, nil } // LockHostViaScript will create the script execution request and update // host_mdm_actions in a single transaction. func (ds *Datastore) LockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error { var res *fleet.HostScriptResult return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var err error scRes, err := insertScriptContents(ctx, tx, request.ScriptContents) if err != nil { return err } id, _ := scRes.LastInsertId() request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115 res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, true) if err != nil { return ctxerr.Wrap(ctx, err, "lock host via script create execution") } // on duplicate we don't clear any other existing state because at this // point in time, this is just a request to lock the host that is recorded, // it is pending execution. The host's state should be updated to "locked" // only when the script execution is successfully completed, and then any // unlock or wipe references should be cleared. const stmt = ` INSERT INTO host_mdm_actions ( host_id, lock_ref, fleet_platform ) VALUES (?,?,?) ON DUPLICATE KEY UPDATE lock_ref = VALUES(lock_ref) ` _, err = tx.ExecContext(ctx, stmt, request.HostID, res.ExecutionID, hostFleetPlatform, ) if err != nil { return ctxerr.Wrap(ctx, err, "lock host via script update mdm actions") } return nil }) } // UnlockHostViaScript will create the script execution request and update // host_mdm_actions in a single transaction. func (ds *Datastore) UnlockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error { var res *fleet.HostScriptResult return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var err error scRes, err := insertScriptContents(ctx, tx, request.ScriptContents) if err != nil { return err } id, _ := scRes.LastInsertId() request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115 res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, true) if err != nil { return ctxerr.Wrap(ctx, err, "unlock host via script create execution") } // on duplicate we don't clear any other existing state because at this // point in time, this is just a request to unlock the host that is // recorded, it is pending execution. The host's state should be updated to // "unlocked" only when the script execution is successfully completed, and // then any lock or wipe references should be cleared. const stmt = ` INSERT INTO host_mdm_actions ( host_id, unlock_ref, fleet_platform ) VALUES (?,?,?) ON DUPLICATE KEY UPDATE unlock_ref = VALUES(unlock_ref), unlock_pin = NULL ` _, err = tx.ExecContext(ctx, stmt, request.HostID, res.ExecutionID, hostFleetPlatform, ) if err != nil { return ctxerr.Wrap(ctx, err, "unlock host via script update mdm actions") } return err }) } // WipeHostViaScript creates the script execution request and updates the // host_mdm_actions table in a single transaction. func (ds *Datastore) WipeHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error { var res *fleet.HostScriptResult return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var err error scRes, err := insertScriptContents(ctx, tx, request.ScriptContents) if err != nil { return err } id, _ := scRes.LastInsertId() request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115 res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, true) if err != nil { return ctxerr.Wrap(ctx, err, "wipe host via script create execution") } // on duplicate we don't clear any other existing state because at this // point in time, this is just a request to wipe the host that is recorded, // it is pending execution, so if it was locked, it is still locked (so the // lock_ref info must still be there). const stmt = ` INSERT INTO host_mdm_actions ( host_id, wipe_ref, fleet_platform ) VALUES (?,?,?) ON DUPLICATE KEY UPDATE wipe_ref = VALUES(wipe_ref) ` _, err = tx.ExecContext(ctx, stmt, request.HostID, res.ExecutionID, hostFleetPlatform, ) if err != nil { return ctxerr.Wrap(ctx, err, "wipe host via script update mdm actions") } return err }) } func (ds *Datastore) UnlockHostManually(ctx context.Context, hostID uint, hostFleetPlatform string, ts time.Time) error { const stmt = ` INSERT INTO host_mdm_actions ( host_id, unlock_ref, fleet_platform ) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE -- do not overwrite if a value is already set unlock_ref = IF(unlock_ref IS NULL, VALUES(unlock_ref), unlock_ref) ` // for macOS, the unlock_ref is just the timestamp at which the user first // requested to unlock the host. This then indicates in the host's status // that it's pending an unlock (which requires manual intervention by // entering a PIN on the device). The /unlock endpoint can be called multiple // times, so we record the timestamp of the first time it was requested and // from then on, the host is marked as "pending unlock" until the device is // actually unlocked with the PIN. The actual unlocking happens when the // device sends an Idle MDM request. unlockRef := ts.Format(time.DateTime) _, err := ds.writer(ctx).ExecContext(ctx, stmt, hostID, unlockRef, hostFleetPlatform) return ctxerr.Wrap(ctx, err, "record manual unlock host request") } func buildHostLockWipeStatusUpdateStmt(refCol string, succeeded bool, joinPart string, setUnlockRef bool) string { var alias string stmt := `UPDATE host_mdm_actions ` if joinPart != "" { stmt += `hma ` + joinPart alias = "hma." } stmt += ` SET ` if succeeded { switch refCol { case "lock_ref": // Note that this must not clear the unlock_pin, because recording the // lock request does generate the PIN and store it there to be used by an // eventual unlock. if !setUnlockRef { stmt += fmt.Sprintf("%sunlock_ref = NULL, %[1]swipe_ref = NULL", alias) } else { // Currently only used for Apple MDM devices. // We set the unlock_ref to current time since the device can be unlocked any time after the lock. // Apple MDM does not have a concept of unlock pending. stmt += fmt.Sprintf("%sunlock_ref = '%s', %[1]swipe_ref = NULL", alias, time.Now().Format(time.DateTime)) } case "unlock_ref": // a successful unlock clears itself as well as the lock ref, because // unlock is the default state so we don't need to keep its unlock_ref // around once it's confirmed. stmt += fmt.Sprintf("%slock_ref = NULL, %[1]sunlock_ref = NULL, %[1]sunlock_pin = NULL, %[1]swipe_ref = NULL", alias) case "wipe_ref": stmt += fmt.Sprintf("%slock_ref = NULL, %[1]sunlock_ref = NULL, %[1]sunlock_pin = NULL", alias) } } else { // if the action failed, then we clear the reference to that action itself so // the host stays in the previous state (it doesn't transition to the new // state). stmt += fmt.Sprintf("%s"+refCol+" = NULL", alias) } return stmt } func (ds *Datastore) UpdateHostLockWipeStatusFromAppleMDMResult(ctx context.Context, hostUUID, cmdUUID, requestType string, succeeded bool) error { // a bit of MDM protocol leaking in the mysql layer, but it's either that or // the other way around (MDM protocol would translate to database column) var refCol string var setUnlockRef bool switch requestType { case "EraseDevice": refCol = "wipe_ref" case "DeviceLock": refCol = "lock_ref" setUnlockRef = true default: return nil } return updateHostLockWipeStatusFromResultAndHostUUID(ctx, ds.writer(ctx), hostUUID, refCol, cmdUUID, succeeded, setUnlockRef) } func updateHostLockWipeStatusFromResultAndHostUUID( ctx context.Context, tx sqlx.ExtContext, hostUUID, refCol, cmdUUID string, succeeded bool, setUnlockRef bool, ) error { stmt := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, `JOIN hosts h ON hma.host_id = h.id`, setUnlockRef) stmt += ` WHERE h.uuid = ? AND hma.` + refCol + ` = ?` _, err := tx.ExecContext(ctx, stmt, hostUUID, cmdUUID) return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result via host uuid") } func updateHostLockWipeStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, refCol string, succeeded bool) error { stmt := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, "", false) stmt += ` WHERE host_id = ?` _, err := tx.ExecContext(ctx, stmt, hostID) return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result") } func (ds *Datastore) updateUninstallStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, executionID string, exitCode int) error { stmt := ` UPDATE host_software_installs SET uninstall_script_exit_code = ? WHERE execution_id = ? AND host_id = ? ` if _, err := tx.ExecContext(ctx, stmt, exitCode, executionID, hostID); err != nil { return ctxerr.Wrap(ctx, err, "update uninstall status from result") } // NOTE: no need to call activateNextUpcomingActivity here as this function // is called from SetHostScriptExecutionResult which will call it before // completing. return nil } func (ds *Datastore) CleanupUnusedScriptContents(ctx context.Context) error { deleteStmt := ` DELETE FROM script_contents WHERE NOT EXISTS ( SELECT 1 FROM host_script_results WHERE script_content_id = script_contents.id) AND NOT EXISTS ( SELECT 1 FROM scripts WHERE script_content_id = script_contents.id) AND NOT EXISTS ( SELECT 1 FROM software_installers si WHERE script_contents.id IN (si.install_script_content_id, si.post_install_script_content_id, si.uninstall_script_content_id) ) AND NOT EXISTS ( SELECT 1 FROM setup_experience_scripts WHERE script_content_id = script_contents.id ) AND NOT EXISTS ( SELECT 1 FROM script_upcoming_activities WHERE script_content_id = script_contents.id ) ` _, err := ds.writer(ctx).ExecContext(ctx, deleteStmt) if err != nil { return ctxerr.Wrap(ctx, err, "cleaning up unused script contents") } return nil } func (ds *Datastore) getOrGenerateScriptContentsID(ctx context.Context, contents string) (uint, error) { csum := md5ChecksumScriptContent(contents) scriptContentsID, err := ds.optimisticGetOrInsert(ctx, ¶meterizedStmt{ Statement: `SELECT id FROM script_contents WHERE md5_checksum = UNHEX(?)`, Args: []interface{}{csum}, }, ¶meterizedStmt{ Statement: `INSERT INTO script_contents (md5_checksum, contents) VALUES (UNHEX(?), ?)`, Args: []interface{}{csum, contents}, }, ) if err != nil { return 0, err } return scriptContentsID, nil } func teamIDEq(teamID1, teamID2 *uint) bool { sameTeamNoTeam := teamID1 == nil && teamID2 == nil sameTeamNumber := teamID1 != nil && teamID2 != nil && *teamID1 == *teamID2 return sameTeamNoTeam || sameTeamNumber } func (ds *Datastore) batchExecuteScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint, batchExecID string) error { script, err := ds.Script(ctx, scriptID) if err != nil { return fleet.NewInvalidArgumentError("script_id", err.Error()) } invalidHostIDPlatform := "batch-invalid-hostid" // We need full host info to check if hosts are able to run scripts, see svc.RunHostScript fullHosts := make([]*fleet.Host, 0, len(hostIDs)) // The execution results to be stored in the database executions := make([]fleet.BatchExecutionHost, 0, len(fullHosts)) // Check that all hosts exist before attempting to process them for _, hostID := range hostIDs { host, err := ds.Host(ctx, hostID) if err != nil { fullHosts = append(fullHosts, &fleet.Host{ ID: hostID, Platform: invalidHostIDPlatform, }) continue } fullHosts = append(fullHosts, host) } if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { for _, host := range fullHosts { // Host doesn't exist anymore if host.Platform == invalidHostIDPlatform { executions = append(executions, fleet.BatchExecutionHost{ HostID: host.ID, Error: &fleet.BatchExecuteInvalidHost, }) continue } // Non-orbit-enrolled host (iOS, android) noNodeKey := host.OrbitNodeKey == nil || *host.OrbitNodeKey == "" // Scripts disabled on host scriptsDisabled := host.ScriptsEnabled != nil && !*host.ScriptsEnabled if noNodeKey || scriptsDisabled { executions = append(executions, fleet.BatchExecutionHost{ HostID: host.ID, Error: &fleet.BatchExecuteIncompatibleFleetd, }) continue } if !fleet.ValidateScriptPlatform(script.Name, host.Platform) { executions = append(executions, fleet.BatchExecutionHost{ HostID: host.ID, Error: &fleet.BatchExecuteIncompatiblePlatform, }) continue } executionID, _, err := ds.insertNewHostScriptExecution(ctx, tx, &fleet.HostScriptRequestPayload{ HostID: host.ID, UserID: userID, ScriptID: &script.ID, ScriptContentID: script.ScriptContentID, }, false) if err != nil { return ctxerr.Wrap(ctx, err, "queueing script for bulk execution") } executions = append(executions, fleet.BatchExecutionHost{ HostID: host.ID, ExecutionID: &executionID, }) } _, err := tx.ExecContext( ctx, `INSERT INTO batch_activities (execution_id, script_id, status, activity_type, num_targeted, started_at) VALUES (?, ?, ?, ?, ?, NOW()) ON DUPLICATE KEY UPDATE status = VALUES(status), started_at = VALUES(started_at)`, batchExecID, script.ID, fleet.ScheduledBatchExecutionStarted, fleet.BatchExecutionActivityScript, len(hostIDs), ) if err != nil { return ctxerr.Wrap(ctx, err, "failed to insert new batch execution") } args := make([]map[string]any, 0, len(executions)) for _, execHost := range executions { args = append(args, map[string]any{ "batch_id": batchExecID, "host_id": execHost.HostID, "host_execution_id": execHost.ExecutionID, "error": execHost.Error, }) } insertStmt := ` INSERT INTO batch_activity_host_results ( batch_execution_id, host_id, host_execution_id, error ) VALUES ( :batch_id, :host_id, :host_execution_id, :error ) ON DUPLICATE KEY UPDATE host_execution_id = VALUES(host_execution_id), error = VALUES(error)` if _, err := sqlx.NamedExecContext(ctx, tx, insertStmt, args); err != nil { return ctxerr.Wrap(ctx, err, "associating script executions with batch job") } return nil }); err != nil { return fmt.Errorf("creating bulk execution order: %w", err) } return nil } func (ds *Datastore) BatchExecuteScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint) (string, error) { batchExecID := uuid.New().String() script, err := ds.Script(ctx, scriptID) if err != nil { return "", fleet.NewInvalidArgumentError("script_id", err.Error()) } for _, hostID := range hostIDs { host, err := ds.HostLite(ctx, hostID) if err != nil { return "", fmt.Errorf("unable to load host information for %d: %w", hostID, err) } if !teamIDEq(host.TeamID, script.TeamID) { return "", ctxerr.Errorf(ctx, "all hosts must be on the same team as the script") } } if err := ds.batchExecuteScript(ctx, userID, scriptID, hostIDs, batchExecID); err != nil { return "", ctxerr.Wrap(ctx, err, "immediate batch execution") } return batchExecID, nil } func (ds *Datastore) BatchScheduleScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint, notBefore time.Time) (string, error) { batchExecID := uuid.New().String() const batchActivitiesStmt = `INSERT INTO batch_activities (execution_id, job_id, script_id, user_id, status, activity_type, num_targeted) VALUES (?, ?, ?, ?, ?, ?, ?)` const batchHostsStmt = `INSERT INTO batch_activity_host_results (batch_execution_id, host_id) VALUES (:exec_id, :host_id)` argBytes, err := json.Marshal(fleet.BatchActivityScriptJobArgs{ ExecutionID: batchExecID, }) if err != nil { return "", ctxerr.Wrap(ctx, err, "encooding job args") } if err := ds.withTx(ctx, func(tx sqlx.ExtContext) error { job, err := ds.NewJob(ctx, &fleet.Job{ Name: fleet.BatchActivityScriptsJobName, Args: (*json.RawMessage)(&argBytes), State: fleet.JobStateQueued, NotBefore: notBefore.UTC(), }) if err != nil { return ctxerr.Wrap(ctx, err, "creating new job") } _, err = tx.ExecContext( ctx, batchActivitiesStmt, batchExecID, job.ID, scriptID, userID, fleet.ScheduledBatchExecutionScheduled, fleet.BatchExecutionActivityScript, len(hostIDs), ) if err != nil { return ctxerr.Wrap(ctx, err, "inserting new batch activity") } args := make([]map[string]any, 0, len(hostIDs)) for _, hostID := range hostIDs { args = append(args, map[string]any{ "exec_id": batchExecID, "host_id": hostID, }) } if _, err := sqlx.NamedExecContext(ctx, tx, batchHostsStmt, args); err != nil { return ctxerr.Wrap(ctx, err, "inserting batch host results") } return nil }); err != nil { return "", ctxerr.Wrap(ctx, err, "creating scheduled script execution") } return batchExecID, nil } func (ds *Datastore) CancelBatchScript(ctx context.Context, executionID string) error { stmt := ` SELECT bahr.host_execution_id, bahr.host_id FROM batch_activity_host_results bahr LEFT JOIN host_script_results hsr ON bahr.host_execution_id = hsr.execution_id -- I think? WHERE bahr.batch_execution_id = ? AND hsr.canceled = 0 AND hsr.exit_code IS NULL AND bahr.error IS NULL` stmtSetCanceled := ` UPDATE batch_activities ba SET finished_at = NOW(), status = 'finished', canceled = 1, num_canceled = (SELECT COUNT(*) FROM batch_activity_host_results WHERE batch_execution_id = ba.execution_id) WHERE ba.execution_id = ?` stmtCanceled := ` UPDATE batch_activities SET canceled = 1 WHERE execution_id = ?` activity, err := ds.GetBatchActivity(ctx, executionID) if err != nil { return ctxerr.Wrap(ctx, err, "getting batch activity") } if activity.Status == fleet.ScheduledBatchExecutionFinished { return nil } if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { // If job worker exists, mark it as complete to stop it from running if jobID := activity.JobID; jobID != nil { job, err := ds.GetJob(ctx, *jobID) if err != nil { return ctxerr.Wrap(ctx, err, "failed to find job associated with batch activity") } job.State = fleet.JobStateSuccess if _, err := ds.updateJob(ctx, tx, *jobID, job); err != nil { return ctxerr.Wrap(ctx, err, "updating batch activity job") } } if activity.Status == fleet.ScheduledBatchExecutionStarted { // If the batch activity has started, we need to cancel anything in progress or queued toCancel := []struct { HostExecutionID string `db:"host_execution_id"` HostID uint `db:"host_id"` }{} if err := sqlx.SelectContext(ctx, tx, &toCancel, stmt, executionID); err != nil { return ctxerr.Wrap(ctx, err, "selecting hosts to cancel") } for _, host := range toCancel { if _, err := ds.cancelHostUpcomingActivity(ctx, tx, host.HostID, host.HostExecutionID); err != nil { return ctxerr.Wrap(ctx, err, "canceling upcoming activity") } } if _, err := tx.ExecContext(ctx, stmtCanceled, executionID); err != nil { return ctxerr.Wrap(ctx, err, "setting canceled column") } if err := ds.markActivitiesAsCompleted(ctx, tx); err != nil { return ctxerr.Wrap(ctx, err, "marking job as complete and summarizing counts") } } else { // The batch activity is scheduled, but not started if _, err := tx.ExecContext(ctx, stmtSetCanceled, executionID); err != nil { return ctxerr.Wrap(ctx, err, "setting canceled host count") } } return nil }); err != nil { return ctxerr.Wrap(ctx, err, "cancel batch script db transaction") } return nil } func (ds *Datastore) GetBatchActivity(ctx context.Context, executionID string) (*fleet.BatchActivity, error) { const stmt = ` SELECT ba.id, ba.script_id, s.name as script_name, ba.execution_id, ba.user_id, ba.job_id, ba.status, ba.activity_type, ba.num_targeted, ba.num_pending, ba.num_ran, ba.num_errored, ba.num_incompatible, ba.num_canceled, ba.created_at, ba.updated_at, ba.started_at, ba.finished_at, ba.canceled FROM batch_activities ba LEFT JOIN scripts s ON s.id = ba.script_id WHERE execution_id = ?` batchActivity := &fleet.BatchActivity{} if err := sqlx.GetContext(ctx, ds.reader(ctx), batchActivity, stmt, executionID); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting batch activity") } return batchActivity, nil } func (ds *Datastore) GetBatchActivityHostResults(ctx context.Context, executionID string) ([]*fleet.BatchActivityHostResult, error) { const stmt = ` SELECT id, batch_execution_id, host_id, host_execution_id, error FROM batch_activity_host_results WHERE batch_execution_id = ?` results := []*fleet.BatchActivityHostResult{} if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, stmt, executionID); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting batch activity host results") } return results, nil } func (ds *Datastore) RunScheduledBatchActivity(ctx context.Context, executionID string) error { batchActivity, err := ds.GetBatchActivity(ctx, executionID) if err != nil { return ctxerr.Wrap(ctx, err, "getting batch activity") } if batchActivity.Status != fleet.ScheduledBatchExecutionScheduled { return ctxerr.New(ctx, "batch job has already been started") } if batchActivity.Canceled { return ctxerr.New(ctx, "batch job was canceled") } if batchActivity.ScriptID == nil { return ctxerr.New(ctx, "no script ID present in batch activity") } script, err := ds.Script(ctx, *batchActivity.ScriptID) if err != nil { return ctxerr.Wrap(ctx, err, "could not get script") } results, err := ds.GetBatchActivityHostResults(ctx, executionID) if err != nil { return ctxerr.Wrap(ctx, err, "getting batch activity host results") } hostIDs := []uint{} for _, result := range results { hostIDs = append(hostIDs, result.HostID) } if err := ds.batchExecuteScript(ctx, batchActivity.UserID, script.ID, hostIDs, batchActivity.BatchExecutionID); err != nil { return ctxerr.Wrap(ctx, err, "scheduled batch script execution") } return nil } // Deprecated; will be removed in favor of ListBatchScriptExecutions when the batch script details page is ready. func (ds *Datastore) BatchExecuteSummary(ctx context.Context, executionID string) (*fleet.BatchActivity, error) { stmtExecutions := ` SELECT COUNT(*) as num_targeted, COUNT(bsehr.error) as num_did_not_run, COUNT(CASE WHEN hsr.exit_code = 0 THEN 1 END) as num_succeeded, COUNT(CASE WHEN hsr.exit_code <> 0 THEN 1 END) as num_failed, COUNT(CASE WHEN hsr.canceled = 1 AND hsr.exit_code IS NULL THEN 1 END) as num_cancelled FROM batch_activity_host_results bsehr LEFT JOIN host_script_results hsr ON bsehr.host_execution_id = hsr.execution_id WHERE bsehr.batch_execution_id = ?` stmtScriptDetails := ` SELECT script_id, s.name as script_name, s.team_id as team_id, bse.created_at as created_at FROM batch_activities bse JOIN scripts s ON bse.script_id = s.id WHERE bse.execution_id = ?` var summary fleet.BatchActivity var temp_summary struct { NumTargeted uint `db:"num_targeted"` NumDidNotRun uint `db:"num_did_not_run"` NumSucceeded uint `db:"num_succeeded"` NumFailed uint `db:"num_failed"` NumCancelled uint `db:"num_cancelled"` } // Fill out the execution details if err := sqlx.GetContext(ctx, ds.reader(ctx), &temp_summary, stmtExecutions, executionID); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary") } summary.NumTargeted = &temp_summary.NumTargeted // NumRan is the number of hosts that actually ran the script successfully. summary.NumRan = &temp_summary.NumSucceeded // NumErrored is the number of hosts that errored out, which includes // both failed and did not run. summary.NumErrored = ptr.Uint(temp_summary.NumFailed + temp_summary.NumDidNotRun) // NumFailed is the number of hosts that were canceled before execution. summary.NumCanceled = &temp_summary.NumCancelled // NumPending is the number of hosts that are pending execution. summary.NumPending = ptr.Uint(temp_summary.NumTargeted - (temp_summary.NumSucceeded + temp_summary.NumFailed + temp_summary.NumDidNotRun + temp_summary.NumCancelled)) // Fill out the script details if err := sqlx.GetContext(ctx, ds.reader(ctx), &summary, stmtScriptDetails, executionID); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting script information for bulk execution summary") } if summary.TeamID == nil { summary.TeamID = ptr.Uint(0) } return &summary, nil } func (ds *Datastore) ListBatchScriptExecutions(ctx context.Context, filter fleet.BatchExecutionStatusFilter) ([]fleet.BatchActivity, error) { stmtExecutions := ` SELECT * FROM ( -- If batch is finished, get the cached host result counts SELECT COALESCE(ba.num_targeted, 0) AS num_targeted, COALESCE(ba.num_incompatible, 0) AS num_incompatible, COALESCE(ba.num_ran, 0) AS num_ran, COALESCE(ba.num_errored, 0) AS num_errored, COALESCE(ba.num_canceled, 0) AS num_canceled, COALESCE(ba.num_pending, 0) AS num_pending, ba.execution_id, ba.script_id, ba.status, ba.canceled, ba.finished_at, ba.started_at, s.name AS script_name, s.global_or_team_id AS team_id, ba.created_at AS created_at, j.not_before AS not_before, ba.id AS id FROM batch_activities ba JOIN scripts s ON ba.script_id = s.id LEFT JOIN jobs j ON j.id = ba.job_id WHERE ( %s ) AND ba.status = 'finished' UNION ALL -- If batch is not finished, calculate the host result counts live. SELECT COUNT(bahr.host_id) AS num_targeted, COUNT(bahr.error) AS num_incompatible, COUNT(IF(hsr.exit_code = 0, 1, NULL)) AS num_ran, COUNT(IF(hsr.exit_code <> 0, 1, NULL)) AS num_errored, COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba.canceled = 1), 1, NULL)) AS num_cancelled, ( COUNT(bahr.host_id) - COUNT(bahr.error) - COUNT(IF(hsr.exit_code = 0, 1, NULL)) - COUNT(IF(hsr.exit_code <> 0, 1, NULL)) - COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba.canceled = 1), 1, NULL)) ) AS num_pending, ba.execution_id, ba.script_id, ba.status, ba.canceled, ba.finished_at, ba.started_at, s.name AS script_name, s.global_or_team_id AS team_id, ba.created_at AS created_at, j.not_before AS not_before, ba.id AS id FROM batch_activities ba LEFT JOIN batch_activity_host_results bahr ON ba.execution_id = bahr.batch_execution_id LEFT JOIN host_script_results hsr ON bahr.host_execution_id = hsr.execution_id JOIN scripts s ON ba.script_id = s.id LEFT JOIN jobs j ON j.id = ba.job_id WHERE ( %s ) AND ba.status <> 'finished' GROUP BY ba.id ) AS u ORDER BY %s LIMIT %d OFFSET %d ` limit := 10 offset := 0 args := []any{} orderBy := []string{"u.created_at DESC", "u.id DESC"} whereClauses := make([]string, 0, 2) // If an execution ID is provided, use it to filter the results. if filter.ExecutionID != nil && *filter.ExecutionID != "" { whereClauses = append(whereClauses, "ba.execution_id = ?") args = append(args, *filter.ExecutionID) } else { // Otherwise filter by status and/or team ID. if filter.Status != nil && *filter.Status != "" { whereClauses = append(whereClauses, "ba.status = ?") args = append(args, *filter.Status) switch *filter.Status { case string(fleet.ScheduledBatchExecutionScheduled): orderBy = append([]string{"u.not_before ASC"}, orderBy...) case string(fleet.ScheduledBatchExecutionStarted): orderBy = append([]string{"u.started_at DESC"}, orderBy...) case string(fleet.ScheduledBatchExecutionFinished): orderBy = append([]string{"u.finished_at DESC"}, orderBy...) default: // no additional ordering } } if filter.TeamID != nil { whereClauses = append(whereClauses, "s.global_or_team_id = ?") args = append(args, *filter.TeamID) } } // Double up the args to use them in both WHERE clauses. args = append(args, args...) // Use pagination parameters if provided. if filter.Limit != nil { limit = int(*filter.Limit) //nolint:gosec // dismiss G115 } if filter.Offset != nil { offset = int(*filter.Offset) //nolint:gosec // dismiss G115 } where := strings.Join(whereClauses, " AND ") stmtExecutions = fmt.Sprintf(stmtExecutions, where, where, strings.Join(orderBy, ", "), limit, offset) var summary []fleet.BatchActivity if err := sqlx.SelectContext(ctx, ds.reader(ctx), &summary, stmtExecutions, args...); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary") } return summary, nil } func (ds *Datastore) CountBatchScriptExecutions(ctx context.Context, filter fleet.BatchExecutionStatusFilter) (int64, error) { stmtExecutions := ` SELECT COUNT(*) FROM batch_activities ba JOIN scripts s ON ba.script_id = s.id WHERE %s ` args := []any{} whereClauses := make([]string, 0, 2) if filter.Status != nil && *filter.Status != "" { whereClauses = append(whereClauses, "ba.status = ?") args = append(args, *filter.Status) } if filter.TeamID != nil { whereClauses = append(whereClauses, "s.global_or_team_id = ?") args = append(args, *filter.TeamID) } where := strings.Join(whereClauses, " AND ") stmtExecutions = fmt.Sprintf(stmtExecutions, where) var count int64 if err := sqlx.GetContext(ctx, ds.reader(ctx), &count, stmtExecutions, args...); err != nil { return 0, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary") } return count, nil } func (ds *Datastore) markActivitiesAsCompleted(ctx context.Context, tx sqlx.ExtContext) error { const stmt = ` UPDATE batch_activities AS ba JOIN ( SELECT ba2.id AS batch_id, COUNT(bahr.host_id) AS num_targeted, COUNT(bahr.error) AS num_incompatible, COUNT(IF(hsr.exit_code = 0, 1, NULL)) AS num_ran, COUNT(IF(hsr.exit_code <> 0, 1, NULL)) AS num_errored, COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba2.canceled = 1), 1, NULL)) AS num_canceled FROM batch_activities AS ba2 LEFT JOIN batch_activity_host_results AS bahr ON ba2.execution_id = bahr.batch_execution_id LEFT JOIN host_script_results AS hsr ON bahr.host_execution_id = hsr.execution_id WHERE ba2.status = 'started' GROUP BY ba2.id HAVING (num_incompatible + num_ran + num_errored + num_canceled) >= num_targeted ) AS agg ON agg.batch_id = ba.id SET ba.status = 'finished', ba.finished_at = NOW(), ba.num_targeted = agg.num_targeted, ba.num_incompatible = agg.num_incompatible, ba.num_ran = agg.num_ran, ba.num_errored = agg.num_errored, ba.num_canceled = agg.num_canceled, ba.num_pending = 0 WHERE ba.status = 'started'; ` // TODO -- use `RETURNING` to return the IDs of the updated activities? _, err := tx.ExecContext(ctx, stmt) if err != nil { return ctxerr.Wrap(ctx, err, "marking activities as completed") } return nil } func (ds *Datastore) MarkActivitiesAsCompleted(ctx context.Context) error { return ds.markActivitiesAsCompleted(ctx, ds.writer(ctx)) }