package mysql import ( "context" "crypto/md5" //nolint:gosec "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" "strings" "time" "unicode/utf8" constants "github.com/fleetdm/fleet/v4/pkg/scripts" "github.com/fleetdm/fleet/v4/server/contexts/ctxerr" "github.com/fleetdm/fleet/v4/server/fleet" common_mysql "github.com/fleetdm/fleet/v4/server/platform/mysql" "github.com/fleetdm/fleet/v4/server/ptr" "github.com/google/uuid" "github.com/jmoiron/sqlx" ) var scriptsAllowedOrderKeys = common_mysql.OrderKeyAllowlist{ "id": "s.id", "name": "s.name", "created_at": "s.created_at", "updated_at": "s.updated_at", } // hostScriptDetailsAllowedOrderKeys is intentionally minimal: the service layer // pins OrderKey to "name" before reaching this datastore method. var hostScriptDetailsAllowedOrderKeys = common_mysql.OrderKeyAllowlist{ "name": "s.name", } func (ds *Datastore) NewHostScriptExecutionRequest(ctx context.Context, request *fleet.HostScriptRequestPayload) (*fleet.HostScriptResult, error) { var res *fleet.HostScriptResult return res, ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var err error if request.ScriptContentID == 0 { // then we are doing a sync execution, so create the contents first scRes, err := insertScriptContents(ctx, tx, request.ScriptContents) if err != nil { return err } id, _ := scRes.LastInsertId() request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115 } res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, false) return err }) } func (ds *Datastore) newHostScriptExecutionRequest(ctx context.Context, tx sqlx.ExtContext, request *fleet.HostScriptRequestPayload, isInternal bool) (*fleet.HostScriptResult, error) { const ( getStmt = ` SELECT ua.id, ua.host_id, ua.execution_id, ua.created_at, sua.script_id, sua.policy_id, ua.user_id, payload->'$.sync_request' AS sync_request, sc.contents as script_contents, sua.setup_experience_script_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = 
sua.upcoming_activity_id INNER JOIN script_contents sc ON sua.script_content_id = sc.id WHERE ua.id = ? ` ) _, activityID, err := ds.insertNewHostScriptExecution(ctx, tx, request, isInternal) if err != nil { return nil, ctxerr.Wrap(ctx, err, "inserting new script execution request") } var script fleet.HostScriptResult err = sqlx.GetContext(ctx, tx, &script, getStmt, activityID) if err != nil { return nil, ctxerr.Wrap(ctx, err, "getting the created host script activity to return") } return &script, nil } func (ds *Datastore) insertNewHostScriptExecution(ctx context.Context, tx sqlx.ExtContext, request *fleet.HostScriptRequestPayload, isInternal bool) (string, int64, error) { const ( insUAStmt = ` INSERT INTO upcoming_activities (host_id, priority, user_id, fleet_initiated, activity_type, execution_id, payload) VALUES (?, ?, ?, ?, 'script', ?, JSON_OBJECT( 'sync_request', ?, 'is_internal', ?, 'user', (SELECT JSON_OBJECT('name', name, 'email', email, 'gravatar_url', gravatar_url) FROM users WHERE id = ?) ) )` insSUAStmt = ` INSERT INTO script_upcoming_activities (upcoming_activity_id, script_id, script_content_id, policy_id, setup_experience_script_id) VALUES (?, ?, ?, ?, ?) 
` ) execID := uuid.New().String() result, err := tx.ExecContext(ctx, insUAStmt, request.HostID, request.Priority(), request.UserID, request.PolicyID != nil, // fleet-initiated if request is via a policy failure execID, request.SyncRequest, isInternal, request.UserID, ) if err != nil { return "", 0, ctxerr.Wrap(ctx, err, "new script upcoming activity") } activityID, _ := result.LastInsertId() _, err = tx.ExecContext(ctx, insSUAStmt, activityID, request.ScriptID, request.ScriptContentID, request.PolicyID, request.SetupExperienceScriptID, ) if err != nil { return "", 0, ctxerr.Wrap(ctx, err, "new join script upcoming activity") } if _, err := ds.activateNextUpcomingActivity(ctx, tx, request.HostID, ""); err != nil { return "", 0, ctxerr.Wrap(ctx, err, "activate next activity") } return execID, activityID, nil } func truncateScriptResult(output string) string { const maxOutputRuneLen = 10000 if len(output) > utf8.UTFMax*maxOutputRuneLen { // truncate the bytes as we know the output is too long, no point // converting more bytes than needed to runes. output = output[len(output)-(utf8.UTFMax*maxOutputRuneLen):] } if utf8.RuneCountInString(output) > maxOutputRuneLen { outputRunes := []rune(output) output = string(outputRunes[len(outputRunes)-maxOutputRuneLen:]) } return output } func (ds *Datastore) SetHostScriptExecutionResult(ctx context.Context, result *fleet.HostScriptResultPayload, attemptNumber *int) (*fleet.HostScriptResult, string, error, ) { const resultExistsStmt = ` SELECT 1 FROM host_script_results WHERE host_id = ? AND execution_id = ? AND exit_code IS NOT NULL ` const updStmt = ` UPDATE host_script_results SET output = ?, runtime = ?, exit_code = ?, timeout = ?, attempt_number = ? WHERE host_id = ? 
AND execution_id = ?` const hostMDMActionsStmt = ` SELECT 'uninstall' AS action FROM host_software_installs WHERE execution_id = :execution_id AND host_id = :host_id UNION -- host_mdm_actions query (and thus row in union) must be last to avoid #25144 SELECT CASE WHEN lock_ref = :execution_id THEN 'lock_ref' WHEN unlock_ref = :execution_id THEN 'unlock_ref' WHEN wipe_ref = :execution_id THEN 'wipe_ref' ELSE '' END AS action FROM host_mdm_actions WHERE host_id = :host_id ` output := truncateScriptResult(result.Output) var hsr *fleet.HostScriptResult var action string err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var resultExists bool err := sqlx.GetContext(ctx, tx, &resultExists, resultExistsStmt, result.HostID, result.ExecutionID) if err != nil && !errors.Is(err, sql.ErrNoRows) { return ctxerr.Wrap(ctx, err, "check if host script result exists") } if resultExists { ds.logger.DebugContext(ctx, "duplicate script execution result sent, will be ignored (original result is preserved)", "host_id", result.HostID, "execution_id", result.ExecutionID, ) // still do the activate next activity to ensure progress as there was // an unexpected flow if we get here. if _, err := ds.activateNextUpcomingActivity(ctx, tx, result.HostID, result.ExecutionID); err != nil { return ctxerr.Wrap(ctx, err, "activate next activity") } // succeed but leave hsr nil return nil } res, err := tx.ExecContext(ctx, updStmt, output, result.Runtime, // Windows error codes are signed 32-bit integers, but are // returned as unsigned integers by the windows API. The // software that receives them is responsible for casting // it to a 32-bit signed integer. 
// See /orbit/pkg/scripts/exec_windows.go int32(result.ExitCode), //nolint:gosec // dismiss G115 result.Timeout, attemptNumber, result.HostID, result.ExecutionID, ) if err != nil { return ctxerr.Wrap(ctx, err, "update host script result") } if n, _ := res.RowsAffected(); n > 0 { // it did update, so return the updated result hsr, err = ds.getHostScriptExecutionResultDB(ctx, tx, result.ExecutionID, scriptExecutionSearchOpts{IncludeCanceled: true}) if err != nil { return ctxerr.Wrap(ctx, err, "load updated host script result") } // look up if that script was a lock/unlock/wipe/uninstall script for that host, // and if so update the host_mdm_actions table accordingly. namedArgs := map[string]any{ "host_id": result.HostID, "execution_id": result.ExecutionID, } stmt, args, err := sqlx.Named(hostMDMActionsStmt, namedArgs) if err != nil { return ctxerr.Wrap(ctx, err, "build named query for host mdm actions") } err = sqlx.GetContext(ctx, tx, &action, stmt, args...) if err != nil && !errors.Is(err, sql.ErrNoRows) { // ignore ErrNoRows, refCol will be empty return ctxerr.Wrap(ctx, err, "lookup host script corresponding mdm action") } switch action { case "": // do nothing case "uninstall": err = ds.updateUninstallStatusFromResult(ctx, tx, result.HostID, result.ExecutionID, result.ExitCode) if err != nil { return ctxerr.Wrap(ctx, err, "update host uninstall action based on script result") } default: // lock/unlock/wipe err = updateHostLockWipeStatusFromResult(ctx, tx, result.HostID, action, result.ExitCode == 0) if err != nil { return ctxerr.Wrap(ctx, err, "update host mdm action based on script result") } } } if _, err := ds.activateNextUpcomingActivity(ctx, tx, result.HostID, result.ExecutionID); err != nil { return ctxerr.Wrap(ctx, err, "activate next activity") } return nil }) if err != nil { return nil, "", err } return hsr, action, nil } func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID uint, onlyShowInternal bool) 
([]*fleet.HostScriptResult, error) { return ds.listUpcomingHostScriptExecutions(ctx, hostID, onlyShowInternal, false) } func (ds *Datastore) ListReadyToExecuteScriptsForHost(ctx context.Context, hostID uint, onlyShowInternal bool) ([]*fleet.HostScriptResult, error) { return ds.listUpcomingHostScriptExecutions(ctx, hostID, onlyShowInternal, true) } func (ds *Datastore) listUpcomingHostScriptExecutions(ctx context.Context, hostID uint, onlyShowInternal, onlyReadyToExecute bool) ([]*fleet.HostScriptResult, error) { extraWhere := "" if onlyShowInternal { // software_uninstalls are implicitly internal extraWhere = " AND COALESCE(ua.payload->'$.is_internal', 1) = 1" } if onlyReadyToExecute { extraWhere += " AND ua.activated_at IS NOT NULL" } // this selects software uninstalls too as they run as scripts listStmt := fmt.Sprintf(` SELECT id, host_id, execution_id, script_id, created_at FROM ( SELECT ua.id, ua.host_id, ua.execution_id, sua.script_id, ua.priority, ua.created_at, IF(ua.activated_at IS NULL, 0, 1) AS topmost FROM upcoming_activities ua -- left join because software_uninstall has no script join LEFT JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.host_id = ? AND ua.activity_type IN ('script', 'software_uninstall') %s ORDER BY topmost DESC, priority DESC, created_at ASC) t`, extraWhere) var results []*fleet.HostScriptResult if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, listStmt, hostID); err != nil { return nil, ctxerr.Wrap(ctx, err, "list pending host script executions") } return results, nil } func (ds *Datastore) IsExecutionPendingForHost(ctx context.Context, hostID uint, scriptID uint) (bool, error) { const getStmt = ` SELECT 1 FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.host_id = ? AND ua.activity_type = 'script' AND sua.script_id = ? 
` var results []*uint if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, getStmt, hostID, scriptID); err != nil { return false, ctxerr.Wrap(ctx, err, "is execution pending for host") } return len(results) > 0, nil } type scriptExecutionSearchOpts struct { IncludeCanceled bool UninstallHostID uint } func (ds *Datastore) GetHostScriptExecutionResult(ctx context.Context, execID string) (*fleet.HostScriptResult, error) { return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID, scriptExecutionSearchOpts{}) } func (ds *Datastore) GetSelfServiceUninstallScriptExecutionResult(ctx context.Context, execID string, hostID uint) (*fleet.HostScriptResult, error) { return ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), execID, scriptExecutionSearchOpts{UninstallHostID: hostID}) } func (ds *Datastore) getHostScriptExecutionResultDB(ctx context.Context, q sqlx.QueryerContext, execID string, opts scriptExecutionSearchOpts) (*fleet.HostScriptResult, error) { var activeParams []any canceledCondition := "" if !opts.IncludeCanceled { canceledCondition = " AND hsr.canceled = 0" } uninstallCondition := "" if opts.UninstallHostID > 0 { uninstallCondition = `JOIN host_software_installs hsi ON hsi.execution_id = hsr.execution_id AND hsi.uninstall = TRUE AND hsr.host_id = ?` activeParams = append(activeParams, opts.UninstallHostID) } activeParams = append(activeParams, execID) getActiveStmt := fmt.Sprintf(` SELECT hsr.id, hsr.host_id, hsr.execution_id, sc.contents as script_contents, hsr.script_id, hsr.policy_id, hsr.output, hsr.runtime, hsr.exit_code, hsr.timeout, hsr.created_at, hsr.user_id, hsr.sync_request, hsr.host_deleted_at, hsr.setup_experience_script_id, hsr.canceled, bahr.batch_execution_id, hsr.attempt_number FROM host_script_results hsr LEFT JOIN batch_activity_host_results bahr ON hsr.execution_id = bahr.host_execution_id JOIN script_contents sc %s WHERE hsr.execution_id = ? 
AND hsr.script_content_id = sc.id %s `, uninstallCondition, canceledCondition) // We don't include upcoming uninstall script executions in results (different activity type, and they're blank anyway) const getUpcomingStmt = ` SELECT 0 as id, ua.host_id, ua.execution_id, sc.contents as script_contents, sua.script_id, sua.policy_id, '' as output, 0 as runtime, NULL as exit_code, NULL as timeout, ua.created_at, ua.user_id, COALESCE(ua.payload->'$.sync_request', 0) as sync_request, NULL as host_deleted_at, sua.setup_experience_script_id, 0 as canceled, NULL as batch_execution_id, NULL as attempt_number FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id INNER JOIN script_contents sc ON sua.script_content_id = sc.id WHERE ua.execution_id = ? AND ua.activity_type = 'script' ` var result fleet.HostScriptResult if err := sqlx.GetContext(ctx, q, &result, getActiveStmt, activeParams...); err != nil { if errors.Is(err, sql.ErrNoRows) { // try with upcoming activities err = sqlx.GetContext(ctx, q, &result, getUpcomingStmt, execID) if errors.Is(err, sql.ErrNoRows) { return nil, ctxerr.Wrap(ctx, notFound("HostScriptResult").WithName(execID)) } } if err != nil { return nil, ctxerr.Wrap(ctx, err, "get host script result") } } return &result, nil } func (ds *Datastore) CountHostScriptAttempts(ctx context.Context, hostID, scriptID, policyID uint) (int, error) { var count int // Only count attempts from the current retry sequence. // When a policy passes, all attempt_number values are reset to 0 to mark them as "old sequence". // We count attempts where attempt_number > 0 (current sequence) OR attempt_number IS NULL (currently being processed). err := sqlx.GetContext(ctx, ds.reader(ctx), &count, ` SELECT COUNT(*) FROM host_script_results WHERE host_id = ? AND script_id = ? AND policy_id = ? 
AND canceled = 0
		AND (attempt_number > 0 OR attempt_number IS NULL)
	`, hostID, scriptID, policyID)
	if err != nil {
		return 0, ctxerr.Wrap(ctx, err, "count host script attempts")
	}
	return count, nil
}

// NewScript stores script.ScriptContents (deduplicated by checksum), creates
// the script entity referencing them, and returns the created script read from
// the primary.
func (ds *Datastore) NewScript(ctx context.Context, script *fleet.Script) (*fleet.Script, error) {
	var res sql.Result
	err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		// first insert script contents
		scRes, err := insertScriptContents(ctx, tx, script.ScriptContents)
		if err != nil {
			return err
		}
		id, _ := scRes.LastInsertId()

		// then create the script entity
		res, err = insertScript(ctx, tx, script, uint(id)) //nolint:gosec // dismiss G115
		return err
	})
	if err != nil {
		return nil, err
	}
	id, _ := res.LastInsertId()
	return ds.getScriptDB(ctx, ds.writer(ctx), uint(id)) //nolint:gosec // dismiss G115
}

// UpdateScriptContents points scriptID at scriptContents (reusing an existing
// script_contents row with the same checksum), removes the previous contents if
// nothing references them anymore (best effort: failures are only logged),
// cancels the script's pending executions and resets its policy-automation
// attempt counters. It returns the updated script.
func (ds *Datastore) UpdateScriptContents(ctx context.Context, scriptID uint, scriptContents string) (*fleet.Script, error) {
	err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error {
		// Get the current script_content_id
		var oldContentID int64
		getCurrentStmt := `SELECT script_content_id FROM scripts WHERE id = ?`
		err := sqlx.GetContext(ctx, tx, &oldContentID, getCurrentStmt, scriptID)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "getting current script content id")
		}

		// Insert or get existing content (insertScriptContents handles deduplication)
		scRes, err := insertScriptContents(ctx, tx, scriptContents)
		if err != nil {
			return ctxerr.Wrap(ctx, err, "inserting/getting script contents")
		}
		newContentID, _ := scRes.LastInsertId()

		if newContentID != oldContentID {
			// Update the script to point to the new content
			updateStmt := `
				UPDATE scripts
				SET script_content_id = ?
				WHERE id = ?
			`
			_, err = tx.ExecContext(ctx, updateStmt, newContentID, scriptID)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "updating script content reference")
			}

			// Try to clean up the old content if no longer used
			// Don't fail the transaction if cleanup fails; just log it
			if err := ds.cleanupScriptContent(ctx, tx, uint(oldContentID)); err != nil { //nolint:gosec
				ds.logger.ErrorContext(ctx, "failed to cleanup orphaned script content", "script_id", scriptID, "old_content_id", oldContentID, "err", err)
				ctxerr.Handle(ctx, err)
			}
		} else {
			// Just update the timestamp
			_, err = tx.ExecContext(ctx, "UPDATE scripts SET updated_at = NOW() WHERE id = ?", scriptID)
			if err != nil {
				return ctxerr.Wrap(ctx, err, "updating script updated_at time")
			}
		}

		// Cancel pending executions
		if err := ds.cancelUpcomingScriptActivities(ctx, tx, scriptID); err != nil {
			return ctxerr.Wrap(ctx, err, "canceling upcoming script executions")
		}

		// When a script is modified reset attempt numbers for policy automations
		if err := ds.resetScriptPolicyAutomationAttempts(ctx, tx, scriptID); err != nil {
			return ctxerr.Wrap(ctx, err, "resetting policy automation attempts for script")
		}

		return nil
	})
	if err != nil {
		return nil, ctxerr.Wrap(ctx, err, "updating script contents")
	}

	// Read back from the primary (as NewScript does): a replica may not have
	// replicated the write yet and would return the previous contents ID.
	return ds.getScriptDB(ctx, ds.writer(ctx), scriptID)
}

// cancelUpcomingScriptActivities cancels every pending execution of scriptID
// using db: each upcoming activity referencing the script is canceled through
// cancelHostUpcomingActivity, then already-activated executions that have no
// exit code yet are flagged as canceled in host_script_results.
func (ds *Datastore) cancelUpcomingScriptActivities(ctx context.Context, db sqlx.ExtContext, scriptID uint) error {
	const stmt = `
	SELECT
		ua.execution_id,
		ua.host_id
	FROM
		script_upcoming_activities sua
		INNER JOIN upcoming_activities ua
			ON ua.id = sua.upcoming_activity_id
	WHERE
		sua.script_id = ?
` var upcomingExecutions []struct { ExecutionID string `db:"execution_id"` HostID uint `db:"host_id"` } if err := sqlx.SelectContext(ctx, db, &upcomingExecutions, stmt, scriptID); err != nil { return ctxerr.Wrap(ctx, err, "selecting upcoming script executions") } for _, upcomingExecution := range upcomingExecutions { if _, err := ds.cancelHostUpcomingActivity(ctx, db, upcomingExecution.HostID, upcomingExecution.ExecutionID, true); err != nil { return ctxerr.Wrap(ctx, err, "canceling upcoming activity") } } // Cancel scripts that were already activated and are in host_script_results but not yet executed const activatedStmt = `UPDATE host_script_results SET canceled = 1 WHERE script_id = ? AND exit_code IS NULL AND canceled = 0` if _, err := db.ExecContext(ctx, activatedStmt, scriptID); err != nil { return ctxerr.Wrap(ctx, err, "canceling activated pending script executions") } return nil } // resetScriptPolicyAutomationAttempts resets all attempt numbers for script executions for policy automations func (ds *Datastore) resetScriptPolicyAutomationAttempts(ctx context.Context, db sqlx.ExecerContext, scriptID uint) error { _, err := db.ExecContext(ctx, ` UPDATE host_script_results SET attempt_number = 0 WHERE script_id = ? AND policy_id IS NOT NULL AND (attempt_number > 0 OR attempt_number IS NULL) `, scriptID) if err != nil { return ctxerr.Wrap(ctx, err, "reset policy automation script attempts") } return nil } func insertScript(ctx context.Context, tx sqlx.ExtContext, script *fleet.Script, scriptContentsID uint) (sql.Result, error) { const insertStmt = ` INSERT INTO scripts ( team_id, global_or_team_id, name, script_content_id ) VALUES (?, ?, ?, ?) 
` var globalOrTeamID uint if script.TeamID != nil { globalOrTeamID = *script.TeamID } res, err := tx.ExecContext(ctx, insertStmt, script.TeamID, globalOrTeamID, script.Name, scriptContentsID) if err != nil { if IsDuplicate(err) { // name already exists for this team/global err = alreadyExists("Script", script.Name) } else if isChildForeignKeyError(err) { // team does not exist err = foreignKey("scripts", fmt.Sprintf("team_id=%v", script.TeamID)) } return nil, ctxerr.Wrap(ctx, err, "insert script") } return res, nil } func insertScriptContents(ctx context.Context, tx sqlx.ExtContext, contents string) (sql.Result, error) { const insertStmt = ` INSERT INTO script_contents ( md5_checksum, contents ) VALUES (UNHEX(?),?) ON DUPLICATE KEY UPDATE id=LAST_INSERT_ID(id) ` md5Checksum := md5ChecksumScriptContent(contents) res, err := tx.ExecContext(ctx, insertStmt, md5Checksum, contents) if err != nil { return nil, ctxerr.Wrap(ctx, err, "insert script contents") } return res, nil } func md5ChecksumScriptContent(s string) string { return md5ChecksumBytes([]byte(s)) } func md5ChecksumBytes(b []byte) string { rawChecksum := md5.Sum(b) //nolint:gosec return strings.ToUpper(hex.EncodeToString(rawChecksum[:])) } func (ds *Datastore) cleanupScriptContent(ctx context.Context, tx sqlx.ExtContext, contentID uint) error { // Check if this content is still being used anywhere var usageCount int stmt := ` SELECT COUNT(*) FROM ( SELECT 1 FROM scripts WHERE script_content_id = ? UNION ALL SELECT 1 FROM setup_experience_scripts WHERE script_content_id = ? UNION ALL SELECT 1 FROM software_installers WHERE install_script_content_id = ? OR uninstall_script_content_id = ? OR post_install_script_content_id = ? UNION ALL SELECT 1 FROM script_upcoming_activities WHERE script_content_id = ? UNION ALL SELECT 1 FROM host_script_results WHERE script_content_id = ? 
) t ` err := sqlx.GetContext(ctx, tx, &usageCount, stmt, contentID, contentID, contentID, contentID, contentID, contentID, contentID) if err != nil { return ctxerr.Wrap(ctx, err, "checking script content usage for cleanup") } if usageCount == 0 { // Not being used, safe to delete deleteStmt := `DELETE FROM script_contents WHERE id = ?` _, err = tx.ExecContext(ctx, deleteStmt, contentID) if err != nil { return ctxerr.Wrap(ctx, err, "deleting unused script content") } } return nil } func (ds *Datastore) Script(ctx context.Context, id uint) (*fleet.Script, error) { return ds.getScriptDB(ctx, ds.reader(ctx), id) } func (ds *Datastore) getScriptDB(ctx context.Context, q sqlx.QueryerContext, id uint) (*fleet.Script, error) { const getStmt = ` SELECT id, team_id, name, created_at, updated_at, script_content_id FROM scripts WHERE id = ? ` var script fleet.Script if err := sqlx.GetContext(ctx, q, &script, getStmt, id); err != nil { if err == sql.ErrNoRows { return nil, notFound("Script").WithID(id) } return nil, ctxerr.Wrap(ctx, err, "get script") } return &script, nil } func (ds *Datastore) GetScriptContents(ctx context.Context, id uint) ([]byte, error) { const getStmt = ` SELECT sc.contents FROM script_contents sc JOIN scripts s ON s.script_content_id = sc.id WHERE s.id = ? ` var contents []byte if err := sqlx.GetContext(ctx, ds.reader(ctx), &contents, getStmt, id); err != nil { if err == sql.ErrNoRows { return nil, notFound("Script").WithID(id) } return nil, ctxerr.Wrap(ctx, err, "get script contents") } return contents, nil } func (ds *Datastore) GetAnyScriptContents(ctx context.Context, id uint) ([]byte, error) { const getStmt = ` SELECT sc.contents FROM script_contents sc WHERE sc.id = ? 
` var contents []byte if err := sqlx.GetContext(ctx, ds.reader(ctx), &contents, getStmt, id); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, notFound("Script").WithID(id) } return nil, ctxerr.Wrap(ctx, err, "get any script contents") } return contents, nil } var errDeleteScriptWithAssociatedPolicy = &fleet.ConflictError{Message: "Couldn't delete. Policy automation uses this script. Please remove this script from associated policy automations and try again."} func (ds *Datastore) DeleteScript(ctx context.Context, id uint) error { var activateAffectedHosts []uint err := ds.withTx(ctx, func(tx sqlx.ExtContext) error { _, err := tx.ExecContext(ctx, `DELETE FROM host_script_results WHERE script_id = ? AND exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND)`, id, int(constants.MaxServerWaitTime.Seconds()), ) if err != nil { return ctxerr.Wrapf(ctx, err, "cancel pending script executions") } // load hosts that will have their upcoming_activities deleted, if that // activity is "activated", as that means we will have to call // activateNextUpcomingActivity for those hosts. loadAffectedHostsStmt := ` SELECT DISTINCT host_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE sua.script_id = ? AND ua.activity_type = 'script' AND ua.activated_at IS NOT NULL AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND)` var affectedHosts []uint if err := sqlx.SelectContext(ctx, tx, &affectedHosts, loadAffectedHostsStmt, id, int(constants.MaxServerWaitTime.Seconds())); err != nil { return ctxerr.Wrapf(ctx, err, "load affected hosts") } activateAffectedHosts = affectedHosts _, err = tx.ExecContext(ctx, `DELETE FROM upcoming_activities USING upcoming_activities INNER JOIN script_upcoming_activities sua ON upcoming_activities.id = sua.upcoming_activity_id WHERE sua.script_id = ? 
AND upcoming_activities.activity_type = 'script' AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND) `, id, int(constants.MaxServerWaitTime.Seconds()), ) if err != nil { return ctxerr.Wrapf(ctx, err, "cancel upcoming pending script executions") } _, err = tx.ExecContext(ctx, `DELETE FROM scripts WHERE id = ?`, id) if err != nil { if isMySQLForeignKey(err) { // Check if the script is referenced by a policy automation. var count int if err := sqlx.GetContext(ctx, tx, &count, `SELECT COUNT(*) FROM policies WHERE script_id = ?`, id); err != nil { return ctxerr.Wrapf(ctx, err, "getting reference from policies") } if count > 0 { return ctxerr.Wrap(ctx, errDeleteScriptWithAssociatedPolicy, "delete script") } } return ctxerr.Wrap(ctx, err, "delete script") } return nil }) if err != nil { return err } // we call this outside of the transaction to avoid a // long-running/deadlock-prone transaction, as many hosts could be affected. return ds.activateNextUpcomingActivityForBatchOfHosts(ctx, activateAffectedHosts) } // deletePendingHostScriptExecutionsForPolicy should be called when a policy is deleted to remove any pending script executions func (ds *Datastore) deletePendingHostScriptExecutionsForPolicy(ctx context.Context, teamID *uint, policyID uint) error { var globalOrTeamID uint if teamID != nil { globalOrTeamID = *teamID } deletePendingFunc := func(stmt string, args ...any) error { _, err := ds.writer(ctx).ExecContext(ctx, stmt, args...) return ctxerr.Wrap(ctx, err, "delete pending host script executions for policy") } deleteHSRStmt := ` DELETE FROM host_script_results WHERE policy_id = ? AND script_id IN ( SELECT id FROM scripts WHERE scripts.global_or_team_id = ? 
) AND exit_code IS NULL ` if err := deletePendingFunc(deleteHSRStmt, policyID, globalOrTeamID); err != nil { return err } loadAffectedHostsStmt := ` SELECT DISTINCT host_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.activity_type = 'script' AND ua.activated_at IS NOT NULL AND sua.policy_id = ? AND sua.script_id IN ( SELECT id FROM scripts WHERE scripts.global_or_team_id = ? )` var affectedHosts []uint if err := sqlx.SelectContext(ctx, ds.reader(ctx), &affectedHosts, loadAffectedHostsStmt, policyID, globalOrTeamID); err != nil { return err } deleteUAStmt := ` DELETE FROM upcoming_activities USING upcoming_activities INNER JOIN script_upcoming_activities sua ON upcoming_activities.id = sua.upcoming_activity_id WHERE upcoming_activities.activity_type = 'script' AND sua.policy_id = ? AND sua.script_id IN ( SELECT id FROM scripts WHERE scripts.global_or_team_id = ? ) ` if err := deletePendingFunc(deleteUAStmt, policyID, globalOrTeamID); err != nil { return err } return ds.activateNextUpcomingActivityForBatchOfHosts(ctx, affectedHosts) } func (ds *Datastore) ListScripts(ctx context.Context, teamID *uint, opt fleet.ListOptions) ([]*fleet.Script, *fleet.PaginationMetadata, error) { var scripts []*fleet.Script const selectStmt = ` SELECT s.id, s.team_id, s.name, s.created_at, s.updated_at FROM scripts s WHERE s.global_or_team_id = ? 
` var globalOrTeamID uint if teamID != nil { globalOrTeamID = *teamID } args := []any{globalOrTeamID} stmt, args, err := appendListOptionsWithCursorToSQLSecure(selectStmt, args, &opt, scriptsAllowedOrderKeys) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "list scripts") } if err := sqlx.SelectContext(ctx, ds.reader(ctx), &scripts, stmt, args...); err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "select scripts") } var metaData *fleet.PaginationMetadata if opt.IncludeMetadata { metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0} if len(scripts) > int(opt.PerPage) { //nolint:gosec // dismiss G115 metaData.HasNextResults = true scripts = scripts[:len(scripts)-1] } } return scripts, metaData, nil } func (ds *Datastore) GetScriptIDByName(ctx context.Context, name string, teamID *uint) (uint, error) { const selectStmt = ` SELECT id FROM scripts WHERE global_or_team_id = ? AND name = ? ` var globalOrTeamID uint if teamID != nil { globalOrTeamID = *teamID } var id uint if err := sqlx.GetContext(ctx, ds.reader(ctx), &id, selectStmt, globalOrTeamID, name); err != nil { if err == sql.ErrNoRows { return 0, notFound("Script").WithName(name) } return 0, ctxerr.Wrap(ctx, err, "get script by name") } return id, nil } func (ds *Datastore) GetHostScriptDetails(ctx context.Context, hostID uint, teamID *uint, opt fleet.ListOptions, hostPlatform string) ([]*fleet.HostScriptDetail, *fleet.PaginationMetadata, error) { var globalOrTeamID uint if teamID != nil { globalOrTeamID = *teamID } var extensionPatterns []string switch { case hostPlatform == "windows": // filter by .ps1 extension extensionPatterns = []string{`%.ps1`} case fleet.IsUnixLike(hostPlatform): // filter by .sh and .py extensions extensionPatterns = []string{`%.sh`, `%.py`} default: // no extension filter } type row struct { ScriptID uint `db:"script_id"` Name string `db:"name"` HSRID *uint `db:"hsr_id"` ExecutionID *string `db:"execution_id"` ExecutedAt *time.Time `db:"executed_at"` ExitCode 
*int64 `db:"exit_code"` } sql := ` WITH all_latest_activities AS ( -- Use window function to efficiently find the latest execution per script -- This is O(n) (a self-join approach would be O(n²)) SELECT * FROM ( SELECT id, host_id, script_id, execution_id, created_at, exit_code, 'completed' as source, ROW_NUMBER() OVER ( PARTITION BY script_id ORDER BY created_at DESC, id DESC ) AS row_num FROM host_script_results WHERE host_id = ? AND canceled = 0 ) completed_ranked WHERE row_num = 1 UNION ALL -- latest from upcoming_activities SELECT * FROM ( SELECT NULL as id, ua.host_id, sua.script_id, ua.execution_id, ua.created_at, NULL as exit_code, 'upcoming' as source, ROW_NUMBER() OVER ( PARTITION BY sua.script_id ORDER BY ua.created_at DESC, ua.id DESC ) AS row_num FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.host_id = ? AND ua.activity_type = 'script' ) upcoming_ranked WHERE row_num = 1 ) SELECT s.id AS script_id, s.name, latest.id AS hsr_id, latest.created_at AS executed_at, latest.execution_id, latest.exit_code FROM scripts s LEFT JOIN ( -- Pick the most recent between completed and upcoming for each script SELECT * FROM ( SELECT *, ROW_NUMBER() OVER ( PARTITION BY script_id ORDER BY CASE WHEN source = 'upcoming' THEN 1 ELSE 2 END, -- Prefer upcoming over completed created_at DESC, id DESC ) AS final_rn FROM all_latest_activities ) final_ranked WHERE final_rn = 1 ) latest ON s.id = latest.script_id WHERE (latest.host_id IS NULL OR latest.host_id = ?) AND s.global_or_team_id = ? 
` args := []any{hostID, hostID, hostID, globalOrTeamID} if len(extensionPatterns) > 0 { likeClauses := make([]string, 0, len(extensionPatterns)) for _, ext := range extensionPatterns { likeClauses = append(likeClauses, "s.name LIKE ?") args = append(args, ext) } sql += ` AND ( ` + strings.Join(likeClauses, ` OR `) + ` ) ` } stmt, args, err := appendListOptionsWithCursorToSQLSecure(sql, args, &opt, hostScriptDetailsAllowedOrderKeys) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get host script details") } var rows []*row if err := sqlx.SelectContext(ctx, ds.reader(ctx), &rows, stmt, args...); err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get host script details") } var metaData *fleet.PaginationMetadata if opt.IncludeMetadata { metaData = &fleet.PaginationMetadata{HasPreviousResults: opt.Page > 0} if len(rows) > int(opt.PerPage) { //nolint:gosec // dismiss G115 metaData.HasNextResults = true rows = rows[:len(rows)-1] } } results := make([]*fleet.HostScriptDetail, 0, len(rows)) for _, r := range rows { results = append(results, fleet.NewHostScriptDetail(hostID, r.ScriptID, r.Name, r.ExecutionID, r.ExecutedAt, r.ExitCode, r.HSRID)) } return results, metaData, nil } func (ds *Datastore) BatchSetScripts(ctx context.Context, tmID *uint, scripts []*fleet.Script) ([]fleet.ScriptResponse, error) { const loadExistingScripts = ` SELECT name FROM scripts WHERE global_or_team_id = ? AND name IN (?) ` const deleteAllScriptsInTeam = ` DELETE FROM scripts WHERE global_or_team_id = ? ` const unsetAllScriptsFromPolicies = `UPDATE policies SET script_id = NULL WHERE team_id = ?` const clearAllPendingExecutionsHSR = `DELETE FROM host_script_results WHERE exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? 
SECOND) AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)` const loadAffectedHostsAllPendingExecutionsUA = ` SELECT DISTINCT host_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.activity_type = 'script' AND ua.activated_at IS NOT NULL AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)` const clearAllPendingExecutionsUA = `DELETE FROM upcoming_activities USING upcoming_activities INNER JOIN script_upcoming_activities sua ON upcoming_activities.id = sua.upcoming_activity_id WHERE upcoming_activities.activity_type = 'script' AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ?)` const unsetScriptsNotInListFromPolicies = ` UPDATE policies SET script_id = NULL WHERE script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?)) ` const deleteScriptsNotInList = ` DELETE FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?) ` const clearPendingExecutionsNotInListHSR = `DELETE FROM host_script_results WHERE exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND) AND script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))` const loadAffectedHostsPendingExecutionsNotInListUA = ` SELECT DISTINCT host_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.activity_type = 'script' AND ua.activated_at IS NOT NULL AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? 
AND name NOT IN (?))` const clearPendingExecutionsNotInListUA = `DELETE FROM upcoming_activities USING upcoming_activities INNER JOIN script_upcoming_activities sua ON upcoming_activities.id = sua.upcoming_activity_id WHERE upcoming_activities.activity_type = 'script' AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id IN (SELECT id FROM scripts WHERE global_or_team_id = ? AND name NOT IN (?))` const insertNewOrEditedScript = ` INSERT INTO scripts ( team_id, global_or_team_id, name, script_content_id ) VALUES (?, ?, ?, ?) ON DUPLICATE KEY UPDATE script_content_id = VALUES(script_content_id), id=LAST_INSERT_ID(id) ` const clearPendingExecutionsWithObsoleteScriptHSR = `DELETE FROM host_script_results WHERE exit_code IS NULL AND (sync_request = 0 OR created_at >= NOW() - INTERVAL ? SECOND) AND script_id = ? AND script_content_id != ?` const loadAffectedHostsPendingExecutionsWithObsoleteScriptUA = ` SELECT DISTINCT host_id FROM upcoming_activities ua INNER JOIN script_upcoming_activities sua ON ua.id = sua.upcoming_activity_id WHERE ua.activity_type = 'script' AND ua.activated_at IS NOT NULL AND (ua.payload->'$.sync_request' = 0 OR ua.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id = ? AND sua.script_content_id != ?` const clearPendingExecutionsWithObsoleteScriptUA = `DELETE FROM upcoming_activities USING upcoming_activities INNER JOIN script_upcoming_activities sua ON upcoming_activities.id = sua.upcoming_activity_id WHERE upcoming_activities.activity_type = 'script' AND (upcoming_activities.payload->'$.sync_request' = 0 OR upcoming_activities.created_at >= NOW() - INTERVAL ? SECOND) AND sua.script_id = ? 
AND sua.script_content_id != ?` const loadInsertedScripts = `SELECT id, team_id, name FROM scripts WHERE global_or_team_id = ?` // use a team id of 0 if no-team var globalOrTeamID uint if tmID != nil { globalOrTeamID = *tmID } // build a list of names for the incoming scripts, will keep the // existing ones if there's a match and no change incomingNames := make([]string, len(scripts)) // at the same time, index the incoming scripts keyed by name for ease // of processing incomingScripts := make(map[string]*fleet.Script, len(scripts)) for i, p := range scripts { incomingNames[i] = p.Name incomingScripts[p.Name] = p } var insertedScripts []fleet.ScriptResponse var activateAffectedHosts []uint if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var existingScripts []*fleet.Script if len(incomingNames) > 0 { // load existing scripts that match the incoming scripts by names stmt, args, err := sqlx.In(loadExistingScripts, globalOrTeamID, incomingNames) if err != nil { return ctxerr.Wrap(ctx, err, "build query to load existing scripts") } if err := sqlx.SelectContext(ctx, tx, &existingScripts, stmt, args...); err != nil { return ctxerr.Wrap(ctx, err, "load existing scripts") } } // figure out if we need to delete any scripts keepNames := make([]string, 0, len(incomingNames)) for _, p := range existingScripts { if newS := incomingScripts[p.Name]; newS != nil { keepNames = append(keepNames, p.Name) } } var ( scriptsStmt string scriptsArgs []any policiesStmt string policiesArgs []any executionsStmt string executionsArgs []any extraExecStmt string extraExecArgs []any err error affectedHostIDs []uint ) if len(keepNames) > 0 { // delete the obsolete scripts scriptsStmt, scriptsArgs, err = sqlx.In(deleteScriptsNotInList, globalOrTeamID, keepNames) if err != nil { return ctxerr.Wrap(ctx, err, "build statement to delete obsolete scripts") } policiesStmt, policiesArgs, err = sqlx.In(unsetScriptsNotInListFromPolicies, globalOrTeamID, keepNames) if err != nil { return 
ctxerr.Wrap(ctx, err, "build statement to unset obsolete scripts from policies") } executionsStmt, executionsArgs, err = sqlx.In(clearPendingExecutionsNotInListHSR, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames) if err != nil { return ctxerr.Wrap(ctx, err, "build statement to clear pending script executions from obsolete scripts") } loadAffectedStmt, args, err := sqlx.In(loadAffectedHostsPendingExecutionsNotInListUA, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames) if err != nil { return ctxerr.Wrap(ctx, err, "build query to load affected hosts for upcoming script executions") } if err := sqlx.SelectContext(ctx, tx, &affectedHostIDs, loadAffectedStmt, args...); err != nil { return ctxerr.Wrap(ctx, err, "load affected hosts for upcoming script executions") } extraExecStmt, extraExecArgs, err = sqlx.In(clearPendingExecutionsNotInListUA, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID, keepNames) if err != nil { return ctxerr.Wrap(ctx, err, "build statement to clear upcoming pending script executions from obsolete scripts") } } else { scriptsStmt = deleteAllScriptsInTeam scriptsArgs = []any{globalOrTeamID} policiesStmt = unsetAllScriptsFromPolicies policiesArgs = []any{globalOrTeamID} executionsStmt = clearAllPendingExecutionsHSR executionsArgs = []any{int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID} if err := sqlx.SelectContext(ctx, tx, &affectedHostIDs, loadAffectedHostsAllPendingExecutionsUA, int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID); err != nil { return ctxerr.Wrap(ctx, err, "load affected hosts for upcoming script executions") } extraExecStmt = clearAllPendingExecutionsUA extraExecArgs = []any{int(constants.MaxServerWaitTime.Seconds()), globalOrTeamID} } if _, err := tx.ExecContext(ctx, policiesStmt, policiesArgs...); err != nil { return ctxerr.Wrap(ctx, err, "unset obsolete scripts from policies") } if _, err := tx.ExecContext(ctx, executionsStmt, executionsArgs...); err != nil 
{ return ctxerr.Wrap(ctx, err, "clear obsolete script pending executions") } if _, err := tx.ExecContext(ctx, extraExecStmt, extraExecArgs...); err != nil { return ctxerr.Wrap(ctx, err, "clear obsolete upcoming script pending executions") } if _, err := tx.ExecContext(ctx, scriptsStmt, scriptsArgs...); err != nil { return ctxerr.Wrap(ctx, err, "delete obsolete scripts") } activateAffectedHosts = affectedHostIDs // insert the new scripts and the ones that have changed for _, s := range incomingScripts { scRes, err := insertScriptContents(ctx, tx, s.ScriptContents) if err != nil { return ctxerr.Wrapf(ctx, err, "inserting script contents for script with name %q", s.Name) } contentID, _ := scRes.LastInsertId() insertRes, err := tx.ExecContext(ctx, insertNewOrEditedScript, tmID, globalOrTeamID, s.Name, uint(contentID)) //nolint:gosec // dismiss G115 if err != nil { return ctxerr.Wrapf(ctx, err, "insert new/edited script with name %q", s.Name) } scriptID, _ := insertRes.LastInsertId() if _, err := tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScriptHSR, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil { return ctxerr.Wrapf(ctx, err, "clear obsolete pending script executions with name %q", s.Name) } var affectedHosts []uint if err := sqlx.SelectContext(ctx, tx, &affectedHosts, loadAffectedHostsPendingExecutionsWithObsoleteScriptUA, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil { return ctxerr.Wrapf(ctx, err, "load affected hosts for upcoming script executions with name %q", s.Name) } activateAffectedHosts = append(activateAffectedHosts, affectedHosts...) 
if _, err = tx.ExecContext(ctx, clearPendingExecutionsWithObsoleteScriptUA, int(constants.MaxServerWaitTime.Seconds()), scriptID, contentID); err != nil { return ctxerr.Wrapf(ctx, err, "clear obsolete upcoming pending script executions with name %q", s.Name) } } if err := sqlx.SelectContext(ctx, tx, &insertedScripts, loadInsertedScripts, globalOrTeamID); err != nil { return ctxerr.Wrap(ctx, err, "load inserted scripts") } return nil }); err != nil { return nil, err } if err := ds.activateNextUpcomingActivityForBatchOfHosts(ctx, activateAffectedHosts); err != nil { return nil, ctxerr.Wrap(ctx, err, "activate next upcoming activity for batch of hosts") } return insertedScripts, nil } type hostMDMActions struct { LockRef *string `db:"lock_ref"` WipeRef *string `db:"wipe_ref"` UnlockRef *string `db:"unlock_ref"` UnlockPIN *string `db:"unlock_pin"` FleetPlatform string `db:"fleet_platform"` } func (ds *Datastore) GetHostLockWipeStatus(ctx context.Context, host *fleet.Host) (*fleet.HostLockWipeStatus, error) { const stmt = ` SELECT lock_ref, wipe_ref, unlock_ref, unlock_pin, fleet_platform FROM host_mdm_actions WHERE host_id = ? ` var mdmActions hostMDMActions hostPlatform := host.FleetPlatform() status := &fleet.HostLockWipeStatus{ HostFleetPlatform: hostPlatform, } if err := sqlx.GetContext(ctx, ds.reader(ctx), &mdmActions, stmt, host.ID); err != nil { if err == sql.ErrNoRows { // do not return a Not Found error, return the zero-value status, which // will report the correct states. return status, nil } return nil, ctxerr.Wrap(ctx, err, "get host lock/wipe status") } // if we have a fleet platform stored in host_mdm_actions, use it instead of // the host.FleetPlatform() because the platform can be overwritten with an // unknown OS name when a Wipe gets executed. 
if mdmActions.FleetPlatform != "" { hostPlatform = mdmActions.FleetPlatform status.HostFleetPlatform = hostPlatform } switch hostPlatform { case "darwin", "ios", "ipados": if mdmActions.UnlockPIN != nil && hostPlatform == "darwin" { // Unlock PIN is only available for macOS hosts status.UnlockPIN = *mdmActions.UnlockPIN } if mdmActions.UnlockRef != nil && hostPlatform == "darwin" { // the unlock reference is a timestamp // (we only store the timestamp for macOS unlocks) var err error status.UnlockRequestedAt, err = time.Parse(time.DateTime, *mdmActions.UnlockRef) if err != nil { // if the format is unexpected but there's something in UnlockRef, just // replace it with the current timestamp, it should still indicate that // an unlock was requested (e.g. in case someone plays with the data // directly in the DB and messes up the format). status.UnlockRequestedAt = time.Now().UTC() } } else if mdmActions.UnlockRef != nil && hostPlatform != "darwin" { // the unlock reference is an MDM command uuid cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.UnlockRef, host.UUID) if err != nil && !fleet.IsNotFound(err) { return nil, ctxerr.Wrap(ctx, err, "get unlock reference") } if fleet.IsNotFound(err) { ds.logger.ErrorContext(ctx, "orphan unlock MDM command reference", "host_id", host.ID, "command_uuid", *mdmActions.UnlockRef) } status.UnlockMDMCommand = cmd status.UnlockMDMCommandResult = cmdRes } if mdmActions.LockRef != nil { // the lock reference is an MDM command cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.LockRef, host.UUID) if err != nil && !fleet.IsNotFound(err) { return nil, ctxerr.Wrap(ctx, err, "get lock reference") } if fleet.IsNotFound(err) { ds.logger.ErrorContext(ctx, "orphan lock MDM command reference", "host_id", host.ID, "command_uuid", *mdmActions.LockRef) } status.LockMDMCommand = cmd status.LockMDMCommandResult = cmdRes // for ADE enrolled iDevices, we don't advance to "locked" until we have location data if 
status.LockMDMCommand != nil && (hostPlatform == "ios" || hostPlatform == "ipados") { _, err = ds.GetHostLocationData(ctx, host.ID) switch { case fleet.IsNotFound(err): status.LocationPending = true case err != nil: return nil, err } } } if mdmActions.WipeRef != nil { // the wipe reference is an MDM command cmd, cmdRes, err := ds.getHostMDMAppleCommand(ctx, *mdmActions.WipeRef, host.UUID) if err != nil && !fleet.IsNotFound(err) { return nil, ctxerr.Wrap(ctx, err, "get wipe reference") } if fleet.IsNotFound(err) { ds.logger.ErrorContext(ctx, "orphan wipe MDM command reference", "host_id", host.ID, "command_uuid", *mdmActions.WipeRef) } status.WipeMDMCommand = cmd status.WipeMDMCommandResult = cmdRes } case "windows", "linux": // lock and unlock references are scripts if mdmActions.LockRef != nil { hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.LockRef, scriptExecutionSearchOpts{IncludeCanceled: true}) if err != nil && !fleet.IsNotFound(err) { return nil, ctxerr.Wrap(ctx, err, "get lock reference script result") } if fleet.IsNotFound(err) { ds.logger.ErrorContext(ctx, "orphan lock script execution reference", "host_id", host.ID, "execution_id", *mdmActions.LockRef) } status.LockScript = hsr } if mdmActions.UnlockRef != nil { hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.UnlockRef, scriptExecutionSearchOpts{IncludeCanceled: true}) if err != nil && !fleet.IsNotFound(err) { return nil, ctxerr.Wrap(ctx, err, "get unlock reference script result") } if fleet.IsNotFound(err) { ds.logger.ErrorContext(ctx, "orphan unlock script execution reference", "host_id", host.ID, "execution_id", *mdmActions.UnlockRef) } status.UnlockScript = hsr } // wipe is an MDM command on Windows, a script on Linux if mdmActions.WipeRef != nil { if hostPlatform == "windows" { cmd, cmdRes, err := ds.getHostMDMWindowsCommand(ctx, *mdmActions.WipeRef, host.UUID) if err != nil && !fleet.IsNotFound(err) { return nil, ctxerr.Wrap(ctx, err, 
"get wipe reference") } if fleet.IsNotFound(err) { ds.logger.ErrorContext(ctx, "orphan wipe MDM command reference", "host_id", host.ID, "command_uuid", *mdmActions.WipeRef) } status.WipeMDMCommand = cmd status.WipeMDMCommandResult = cmdRes } else { hsr, err := ds.getHostScriptExecutionResultDB(ctx, ds.reader(ctx), *mdmActions.WipeRef, scriptExecutionSearchOpts{IncludeCanceled: true}) if err != nil && !fleet.IsNotFound(err) { return nil, ctxerr.Wrap(ctx, err, "get wipe reference script result") } if fleet.IsNotFound(err) { ds.logger.ErrorContext(ctx, "orphan wipe script execution reference", "host_id", host.ID, "execution_id", *mdmActions.WipeRef) } status.WipeScript = hsr } } } return status, nil } // GetHostsLockWipeStatusBatch gets the lock/unlock and wipe status for multiple hosts efficiently. func (ds *Datastore) GetHostsLockWipeStatusBatch(ctx context.Context, hosts []*fleet.Host) (map[uint]*fleet.HostLockWipeStatus, error) { if len(hosts) == 0 { return make(map[uint]*fleet.HostLockWipeStatus), nil } // Build list of host IDs for queries hostIDs := make([]uint, 0, len(hosts)) for _, host := range hosts { hostIDs = append(hostIDs, host.ID) } // Query all host_mdm_actions for these hosts stmt := ` SELECT host_id, lock_ref, wipe_ref, unlock_ref, unlock_pin, fleet_platform FROM host_mdm_actions WHERE host_id IN (?) 
` query, args, err := sqlx.In(stmt, hostIDs) if err != nil { return nil, ctxerr.Wrap(ctx, err, "build IN query for host_mdm_actions") } var mdmActionsRows []struct { HostID uint `db:"host_id"` hostMDMActions } if err := sqlx.SelectContext(ctx, ds.reader(ctx), &mdmActionsRows, query, args...); err != nil { return nil, ctxerr.Wrap(ctx, err, "select host_mdm_actions batch") } // Collect all command/script UUIDs that need to be queried, organized by type and platform type refKey struct { uuid string hostUUID string hostID uint refType string // "lock", "unlock", "wipe" } appleCommandRefs := make([]refKey, 0) windowsCommandRefs := make([]refKey, 0) scriptRefs := make([]refKey, 0) // Build initial status map with platform info statusMap := make(map[uint]*fleet.HostLockWipeStatus, len(hosts)) mdmActionsMap := make(map[uint]*hostMDMActions) for _, row := range mdmActionsRows { mdmActionsMap[row.HostID] = &hostMDMActions{ LockRef: row.LockRef, WipeRef: row.WipeRef, UnlockRef: row.UnlockRef, UnlockPIN: row.UnlockPIN, FleetPlatform: row.FleetPlatform, } } // Initialize status for all hosts and collect refs to query for _, host := range hosts { fleetPlatform := host.FleetPlatform() status := &fleet.HostLockWipeStatus{ HostFleetPlatform: fleetPlatform, } mdmActions, hasMDMActions := mdmActionsMap[host.ID] statusMap[host.ID] = status if !hasMDMActions { continue } // Use stored platform if available if mdmActions.FleetPlatform != "" { fleetPlatform = mdmActions.FleetPlatform status.HostFleetPlatform = fleetPlatform } // Handle macOS unlock PIN (darwin only) if mdmActions.UnlockPIN != nil && fleetPlatform == "darwin" { status.UnlockPIN = *mdmActions.UnlockPIN } // Collect command/script references based on platform switch fleetPlatform { case "darwin", "ios", "ipados": // Apple platforms use MDM commands for lock, unlock (ios/ipados only), and wipe if mdmActions.LockRef != nil { appleCommandRefs = append(appleCommandRefs, refKey{ uuid: *mdmActions.LockRef, hostUUID: host.UUID, 
hostID: host.ID, refType: "lock", }) } if mdmActions.UnlockRef != nil && fleetPlatform != "darwin" { // iOS/iPadOS use MDM command for unlock, darwin uses timestamp appleCommandRefs = append(appleCommandRefs, refKey{ uuid: *mdmActions.UnlockRef, hostUUID: host.UUID, hostID: host.ID, refType: "unlock", }) } else if mdmActions.UnlockRef != nil && fleetPlatform == "darwin" { // For macOS, unlock_ref is a timestamp, parse it here unlockTime, err := time.Parse(time.DateTime, *mdmActions.UnlockRef) if err != nil { // Use current time if format is unexpected unlockTime = time.Now().UTC() } status.UnlockRequestedAt = unlockTime } if mdmActions.WipeRef != nil { appleCommandRefs = append(appleCommandRefs, refKey{ uuid: *mdmActions.WipeRef, hostUUID: host.UUID, hostID: host.ID, refType: "wipe", }) } case "windows": // Windows uses scripts for lock/unlock, MDM command for wipe if mdmActions.LockRef != nil { scriptRefs = append(scriptRefs, refKey{ uuid: *mdmActions.LockRef, hostID: host.ID, refType: "lock", }) } if mdmActions.UnlockRef != nil { scriptRefs = append(scriptRefs, refKey{ uuid: *mdmActions.UnlockRef, hostID: host.ID, refType: "unlock", }) } if mdmActions.WipeRef != nil { windowsCommandRefs = append(windowsCommandRefs, refKey{ uuid: *mdmActions.WipeRef, hostUUID: host.UUID, hostID: host.ID, refType: "wipe", }) } case "linux": // Linux uses scripts for lock, unlock, and wipe if mdmActions.LockRef != nil { scriptRefs = append(scriptRefs, refKey{ uuid: *mdmActions.LockRef, hostID: host.ID, refType: "lock", }) } if mdmActions.UnlockRef != nil { scriptRefs = append(scriptRefs, refKey{ uuid: *mdmActions.UnlockRef, hostID: host.ID, refType: "unlock", }) } if mdmActions.WipeRef != nil { scriptRefs = append(scriptRefs, refKey{ uuid: *mdmActions.WipeRef, hostID: host.ID, refType: "wipe", }) } } } // Batch query Apple MDM commands if len(appleCommandRefs) > 0 { cmdUUIDs := make([]string, 0, len(appleCommandRefs)) cmdUUIDMap := make(map[string][]refKey) for _, ref := range 
appleCommandRefs { if _, exists := cmdUUIDMap[ref.uuid]; !exists { cmdUUIDs = append(cmdUUIDs, ref.uuid) } cmdUUIDMap[ref.uuid] = append(cmdUUIDMap[ref.uuid], ref) } // Query commands cmdStmt := `SELECT command_uuid, request_type FROM nano_commands WHERE command_uuid IN (?)` cmdQuery, cmdArgs, err := sqlx.In(cmdStmt, cmdUUIDs) if err != nil { return nil, ctxerr.Wrap(ctx, err, "build IN query for apple commands") } var commands []struct { CommandUUID string `db:"command_uuid"` RequestType string `db:"request_type"` } if err := sqlx.SelectContext(ctx, ds.reader(ctx), &commands, cmdQuery, cmdArgs...); err != nil { return nil, ctxerr.Wrap(ctx, err, "select apple mdm commands batch") } commandMap := make(map[string]*fleet.MDMCommand) for _, cmd := range commands { commandMap[cmd.CommandUUID] = &fleet.MDMCommand{ CommandUUID: cmd.CommandUUID, RequestType: cmd.RequestType, } } // Query command results resultStmt := `SELECT command_uuid, id, status FROM nano_command_results WHERE command_uuid IN (?)` resultQuery, resultArgs, err := sqlx.In(resultStmt, cmdUUIDs) if err != nil { return nil, ctxerr.Wrap(ctx, err, "build IN query for apple command results") } var results []struct { CommandUUID string `db:"command_uuid"` ID string `db:"id"` Status string `db:"status"` } if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, resultQuery, resultArgs...); err != nil { return nil, ctxerr.Wrap(ctx, err, "select apple mdm command results batch") } // Build map of command_uuid -> host_uuid -> result resultMap := make(map[string]map[string]*fleet.MDMCommandResult) for _, res := range results { if resultMap[res.CommandUUID] == nil { resultMap[res.CommandUUID] = make(map[string]*fleet.MDMCommandResult) } // Only keep terminal statuses if res.Status == fleet.MDMAppleStatusAcknowledged || res.Status == fleet.MDMAppleStatusError || res.Status == fleet.MDMAppleStatusCommandFormatError { resultMap[res.CommandUUID][res.ID] = &fleet.MDMCommandResult{ CommandUUID: res.CommandUUID, Status: 
res.Status, HostUUID: res.ID, } } } // Assign commands and results to status objects for cmdUUID, refs := range cmdUUIDMap { cmd := commandMap[cmdUUID] for _, ref := range refs { status := statusMap[ref.hostID] var cmdRes *fleet.MDMCommandResult if resultMap[cmdUUID] != nil { cmdRes = resultMap[cmdUUID][ref.hostUUID] } switch ref.refType { case "lock": status.LockMDMCommand = cmd status.LockMDMCommandResult = cmdRes case "unlock": status.UnlockMDMCommand = cmd status.UnlockMDMCommandResult = cmdRes case "wipe": status.WipeMDMCommand = cmd status.WipeMDMCommandResult = cmdRes } } } } // Batch query Windows MDM commands if len(windowsCommandRefs) > 0 { cmdUUIDs := make([]string, 0, len(windowsCommandRefs)) cmdUUIDMap := make(map[string][]refKey) for _, ref := range windowsCommandRefs { if _, exists := cmdUUIDMap[ref.uuid]; !exists { cmdUUIDs = append(cmdUUIDs, ref.uuid) } cmdUUIDMap[ref.uuid] = append(cmdUUIDMap[ref.uuid], ref) } // Query commands cmdStmt := `SELECT command_uuid, target_loc_uri FROM windows_mdm_commands WHERE command_uuid IN (?)` cmdQuery, cmdArgs, err := sqlx.In(cmdStmt, cmdUUIDs) if err != nil { return nil, ctxerr.Wrap(ctx, err, "build IN query for windows commands") } var commands []struct { CommandUUID string `db:"command_uuid"` TargetLocURI string `db:"target_loc_uri"` } if err := sqlx.SelectContext(ctx, ds.reader(ctx), &commands, cmdQuery, cmdArgs...); err != nil { return nil, ctxerr.Wrap(ctx, err, "select windows mdm commands batch") } commandMap := make(map[string]*fleet.MDMCommand) for _, cmd := range commands { commandMap[cmd.CommandUUID] = &fleet.MDMCommand{ CommandUUID: cmd.CommandUUID, RequestType: cmd.TargetLocURI, } } // Query command results - JOIN with enrollments to get host_uuid resultStmt := ` SELECT wcr.command_uuid, we.host_uuid, wcr.status_code FROM windows_mdm_command_results wcr INNER JOIN mdm_windows_enrollments we ON wcr.enrollment_id = we.id WHERE wcr.command_uuid IN (?)` resultQuery, resultArgs, err := sqlx.In(resultStmt, 
cmdUUIDs) if err != nil { return nil, ctxerr.Wrap(ctx, err, "build IN query for windows command results") } var results []struct { CommandUUID string `db:"command_uuid"` HostUUID string `db:"host_uuid"` StatusCode string `db:"status_code"` } if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, resultQuery, resultArgs...); err != nil { return nil, ctxerr.Wrap(ctx, err, "select windows mdm command results batch") } // Build map of command_uuid -> host_uuid -> result resultMap := make(map[string]map[string]*fleet.MDMCommandResult) for _, res := range results { if resultMap[res.CommandUUID] == nil { resultMap[res.CommandUUID] = make(map[string]*fleet.MDMCommandResult) } resultMap[res.CommandUUID][res.HostUUID] = &fleet.MDMCommandResult{ CommandUUID: res.CommandUUID, Status: res.StatusCode, HostUUID: res.HostUUID, } } // Assign commands and results to status objects for cmdUUID, refs := range cmdUUIDMap { cmd := commandMap[cmdUUID] for _, ref := range refs { status := statusMap[ref.hostID] var cmdRes *fleet.MDMCommandResult if resultMap[cmdUUID] != nil { cmdRes = resultMap[cmdUUID][ref.hostUUID] } status.WipeMDMCommand = cmd status.WipeMDMCommandResult = cmdRes } } } // Batch query script results if len(scriptRefs) > 0 { execIDs := make([]string, 0, len(scriptRefs)) execIDMap := make(map[string]refKey) for _, ref := range scriptRefs { execIDs = append(execIDs, ref.uuid) execIDMap[ref.uuid] = ref } scriptStmt := ` SELECT execution_id, exit_code, canceled FROM host_script_results WHERE execution_id IN (?) 
` scriptQuery, scriptArgs, err := sqlx.In(scriptStmt, execIDs) if err != nil { return nil, ctxerr.Wrap(ctx, err, "build IN query for script results") } var scriptResults []struct { ExecutionID string `db:"execution_id"` ExitCode *int64 `db:"exit_code"` Canceled bool `db:"canceled"` } if err := sqlx.SelectContext(ctx, ds.reader(ctx), &scriptResults, scriptQuery, scriptArgs...); err != nil { return nil, ctxerr.Wrap(ctx, err, "select script results batch") } scriptResultMap := make(map[string]*fleet.HostScriptResult) for _, sr := range scriptResults { scriptResultMap[sr.ExecutionID] = &fleet.HostScriptResult{ ExecutionID: sr.ExecutionID, ExitCode: sr.ExitCode, Canceled: sr.Canceled, } } // Assign script results to status objects for execID, ref := range execIDMap { status := statusMap[ref.hostID] scriptResult := scriptResultMap[execID] switch ref.refType { case "lock": status.LockScript = scriptResult case "unlock": status.UnlockScript = scriptResult case "wipe": status.WipeScript = scriptResult } } } return statusMap, nil } func (ds *Datastore) getHostMDMWindowsCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) { cmd, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command") } // get the MDM command result, which may be not found (indicating the command doesn't exist). // If it is pending, then it returns 101, and result will be empty. cmdResults, err := ds.GetMDMWindowsCommandResults(ctx, cmdUUID, "") if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get Windows MDM command result") } // each item in the slice returned by GetMDMWindowsCommandResults is // potentially a result for a different host, we need to find the one for // that specific host. 
var cmdRes *fleet.MDMCommandResult for _, r := range cmdResults { if r.HostUUID != hostUUID { continue } if r.Status == "101" || string(r.Result) == "" { // command is still pending continue } // all statuses for Windows indicate end of processing of the command // (there is no equivalent of "NotNow" or "Idle" as for Apple). cmdRes = r break } return cmd, cmdRes, nil } func (ds *Datastore) getHostMDMAppleCommand(ctx context.Context, cmdUUID, hostUUID string) (*fleet.MDMCommand, *fleet.MDMCommandResult, error) { cmd, err := ds.getMDMCommand(ctx, ds.reader(ctx), cmdUUID) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command") } // get the MDM command result, which may be not found (indicating the command // is pending). Note that it doesn't return ErrNoRows if not found, it // returns success and an empty cmdRes slice. cmdResults, err := ds.GetMDMAppleCommandResults(ctx, cmdUUID, hostUUID) if err != nil { return nil, nil, ctxerr.Wrap(ctx, err, "get Apple MDM command result") } // filter by result status to preserve old behavior of this method where it doesn't return pending results. var cmdRes *fleet.MDMCommandResult for _, r := range cmdResults { if r.HostUUID != hostUUID { // this should never happen because we already filter by hostUUID, but just in case continue } if r.Status == fleet.MDMAppleStatusAcknowledged || r.Status == fleet.MDMAppleStatusError || r.Status == fleet.MDMAppleStatusCommandFormatError { cmdRes = r break } } return cmd, cmdRes, nil } // LockHostViaScript will create the script execution request and update // host_mdm_actions in a single transaction. 
func (ds *Datastore) LockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error { var res *fleet.HostScriptResult return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var err error scRes, err := insertScriptContents(ctx, tx, request.ScriptContents) if err != nil { return err } id, _ := scRes.LastInsertId() request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115 res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, true) if err != nil { return ctxerr.Wrap(ctx, err, "lock host via script create execution") } // on duplicate we don't clear any other existing state because at this // point in time, this is just a request to lock the host that is recorded, // it is pending execution. The host's state should be updated to "locked" // only when the script execution is successfully completed, and then any // unlock or wipe references should be cleared. const stmt = ` INSERT INTO host_mdm_actions ( host_id, lock_ref, fleet_platform ) VALUES (?,?,?) ON DUPLICATE KEY UPDATE lock_ref = VALUES(lock_ref) ` _, err = tx.ExecContext(ctx, stmt, request.HostID, res.ExecutionID, hostFleetPlatform, ) if err != nil { return ctxerr.Wrap(ctx, err, "lock host via script update mdm actions") } return nil }) } // UnlockHostViaScript will create the script execution request and update // host_mdm_actions in a single transaction. 
func (ds *Datastore) UnlockHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error { var res *fleet.HostScriptResult return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var err error scRes, err := insertScriptContents(ctx, tx, request.ScriptContents) if err != nil { return err } id, _ := scRes.LastInsertId() request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115 res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, true) if err != nil { return ctxerr.Wrap(ctx, err, "unlock host via script create execution") } // on duplicate we don't clear any other existing state because at this // point in time, this is just a request to unlock the host that is // recorded, it is pending execution. The host's state should be updated to // "unlocked" only when the script execution is successfully completed, and // then any lock or wipe references should be cleared. const stmt = ` INSERT INTO host_mdm_actions ( host_id, unlock_ref, fleet_platform ) VALUES (?,?,?) ON DUPLICATE KEY UPDATE unlock_ref = VALUES(unlock_ref), unlock_pin = NULL ` _, err = tx.ExecContext(ctx, stmt, request.HostID, res.ExecutionID, hostFleetPlatform, ) if err != nil { return ctxerr.Wrap(ctx, err, "unlock host via script update mdm actions") } return err }) } // WipeHostViaScript creates the script execution request and updates the // host_mdm_actions table in a single transaction. 
func (ds *Datastore) WipeHostViaScript(ctx context.Context, request *fleet.HostScriptRequestPayload, hostFleetPlatform string) error { var res *fleet.HostScriptResult return ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { var err error scRes, err := insertScriptContents(ctx, tx, request.ScriptContents) if err != nil { return err } id, _ := scRes.LastInsertId() request.ScriptContentID = uint(id) //nolint:gosec // dismiss G115 res, err = ds.newHostScriptExecutionRequest(ctx, tx, request, true) if err != nil { return ctxerr.Wrap(ctx, err, "wipe host via script create execution") } // on duplicate we don't clear any other existing state because at this // point in time, this is just a request to wipe the host that is recorded, // it is pending execution, so if it was locked, it is still locked (so the // lock_ref info must still be there). const stmt = ` INSERT INTO host_mdm_actions ( host_id, wipe_ref, fleet_platform ) VALUES (?,?,?) ON DUPLICATE KEY UPDATE wipe_ref = VALUES(wipe_ref) ` _, err = tx.ExecContext(ctx, stmt, request.HostID, res.ExecutionID, hostFleetPlatform, ) if err != nil { return ctxerr.Wrap(ctx, err, "wipe host via script update mdm actions") } return err }) } // UnlockHostManually records a manual unlock request for the given host. // ts must be in UTC to ensure consistency with the STR_TO_DATE comparison in CleanAppleMDMLock. func (ds *Datastore) UnlockHostManually(ctx context.Context, hostID uint, hostFleetPlatform string, ts time.Time) error { const stmt = ` INSERT INTO host_mdm_actions ( host_id, unlock_ref, fleet_platform ) VALUES (?, ?, ?) ON DUPLICATE KEY UPDATE -- do not overwrite if a value is already set unlock_ref = IF(unlock_ref IS NULL, VALUES(unlock_ref), unlock_ref) ` // for macOS, the unlock_ref is just the timestamp at which the user first // requested to unlock the host. This then indicates in the host's status // that it's pending an unlock (which requires manual intervention by // entering a PIN on the device). 
The /unlock endpoint can be called multiple // times, so we record the timestamp of the first time it was requested and // from then on, the host is marked as "pending unlock" until the device is // actually unlocked with the PIN. The actual unlocking happens when the // device sends an Idle MDM request. unlockRef := ts.UTC().Format(time.DateTime) _, err := ds.writer(ctx).ExecContext(ctx, stmt, hostID, unlockRef, hostFleetPlatform) return ctxerr.Wrap(ctx, err, "record manual unlock host request") } func buildHostLockWipeStatusUpdateStmt(refCol string, succeeded bool, joinPart string, setUnlockRef bool) string { var alias string stmt := `UPDATE host_mdm_actions ` if joinPart != "" { stmt += `hma ` + joinPart alias = "hma." } stmt += ` SET ` if succeeded { switch refCol { case "lock_ref": // Note that this must not clear the unlock_pin, because recording the // lock request does generate the PIN and store it there to be used by an // eventual unlock. if !setUnlockRef { stmt += fmt.Sprintf("%sunlock_ref = NULL, %[1]swipe_ref = NULL", alias) } else { // Currently only used for Apple MDM devices. // We set the unlock_ref to current time since the device can be unlocked any time after the lock. // Apple MDM does not have a concept of unlock pending. // UTC_TIMESTAMP() is used to ensure timezone consistency with the comparison in CleanAppleMDMLock. stmt += fmt.Sprintf("%sunlock_ref = UTC_TIMESTAMP(), %[1]swipe_ref = NULL", alias) } case "unlock_ref": // a successful unlock clears itself as well as the lock ref, because // unlock is the default state so we don't need to keep its unlock_ref // around once it's confirmed. 
stmt += fmt.Sprintf("%slock_ref = NULL, %[1]sunlock_ref = NULL, %[1]sunlock_pin = NULL, %[1]swipe_ref = NULL", alias) case "wipe_ref": stmt += fmt.Sprintf("%slock_ref = NULL, %[1]sunlock_ref = NULL, %[1]sunlock_pin = NULL", alias) } } else { // if the action failed, then we clear the reference to that action itself so // the host stays in the previous state (it doesn't transition to the new // state). stmt += fmt.Sprintf("%s"+refCol+" = NULL", alias) } return stmt } func (ds *Datastore) UpdateHostLockWipeStatusFromAppleMDMResult(ctx context.Context, hostUUID, cmdUUID, requestType string, succeeded bool) error { // a bit of MDM protocol leaking in the mysql layer, but it's either that or // the other way around (MDM protocol would translate to database column) var refCol string var setUnlockRef bool switch requestType { case "EraseDevice": refCol = "wipe_ref" case "DeviceLock": refCol = "lock_ref" setUnlockRef = true case "EnableLostMode": refCol = "lock_ref" case "DisableLostMode": refCol = "unlock_ref" default: return nil } _, err := updateHostLockWipeStatusFromResultAndHostUUID(ctx, ds.writer(ctx), hostUUID, refCol, cmdUUID, succeeded, setUnlockRef) return err } func updateHostLockWipeStatusFromResultAndHostUUID( ctx context.Context, tx sqlx.ExtContext, hostUUID, refCol, cmdUUID string, succeeded bool, setUnlockRef bool, ) (int64, error) { stmt := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, `JOIN hosts h ON hma.host_id = h.id`, setUnlockRef) stmt += ` WHERE h.uuid = ? 
AND hma.` + refCol + ` = ?` res, err := tx.ExecContext(ctx, stmt, hostUUID, cmdUUID) if err != nil { return 0, ctxerr.Wrap(ctx, err, "update host lock/wipe status from result via host uuid") } n, err := res.RowsAffected() if err != nil { return 0, ctxerr.Wrap(ctx, err, "get rows affected for host lock/wipe status update") } return n, nil } func updateHostLockWipeStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, refCol string, succeeded bool) error { stmt := buildHostLockWipeStatusUpdateStmt(refCol, succeeded, "", false) stmt += ` WHERE host_id = ?` _, err := tx.ExecContext(ctx, stmt, hostID) return ctxerr.Wrap(ctx, err, "update host lock/wipe status from result") } func (ds *Datastore) updateUninstallStatusFromResult(ctx context.Context, tx sqlx.ExtContext, hostID uint, executionID string, exitCode int) error { stmt := ` UPDATE host_software_installs SET uninstall_script_exit_code = ? WHERE execution_id = ? AND host_id = ? ` if _, err := tx.ExecContext(ctx, stmt, exitCode, executionID, hostID); err != nil { return ctxerr.Wrap(ctx, err, "update uninstall status from result") } // NOTE: no need to call activateNextUpcomingActivity here as this function // is called from SetHostScriptExecutionResult which will call it before // completing. 
return nil } func (ds *Datastore) CleanupUnusedScriptContents(ctx context.Context) error { deleteStmt := ` DELETE FROM script_contents WHERE NOT EXISTS ( SELECT 1 FROM host_script_results WHERE script_content_id = script_contents.id) AND NOT EXISTS ( SELECT 1 FROM scripts WHERE script_content_id = script_contents.id) AND NOT EXISTS ( SELECT 1 FROM software_installers si WHERE script_contents.id IN (si.install_script_content_id, si.post_install_script_content_id, si.uninstall_script_content_id) ) AND NOT EXISTS ( SELECT 1 FROM setup_experience_scripts WHERE script_content_id = script_contents.id ) AND NOT EXISTS ( SELECT 1 FROM script_upcoming_activities WHERE script_content_id = script_contents.id ) ` _, err := ds.writer(ctx).ExecContext(ctx, deleteStmt) if err != nil { return ctxerr.Wrap(ctx, err, "cleaning up unused script contents") } return nil } func (ds *Datastore) getOrGenerateScriptContentsID(ctx context.Context, contents string) (uint, error) { csum := md5ChecksumScriptContent(contents) scriptContentsID, err := ds.optimisticGetOrInsert(ctx, ¶meterizedStmt{ Statement: `SELECT id FROM script_contents WHERE md5_checksum = UNHEX(?)`, Args: []interface{}{csum}, }, ¶meterizedStmt{ Statement: `INSERT INTO script_contents (md5_checksum, contents) VALUES (UNHEX(?), ?)`, Args: []interface{}{csum, contents}, }, ) if err != nil { return 0, err } return scriptContentsID, nil } func teamIDEq(teamID1, teamID2 *uint) bool { sameTeamNoTeam := teamID1 == nil && teamID2 == nil sameTeamNumber := teamID1 != nil && teamID2 != nil && *teamID1 == *teamID2 return sameTeamNoTeam || sameTeamNumber } func (ds *Datastore) batchExecuteScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint, batchExecID string) error { script, err := ds.Script(ctx, scriptID) if err != nil { return fleet.NewInvalidArgumentError("script_id", err.Error()) } invalidHostIDPlatform := "batch-invalid-hostid" // We need full host info to check if hosts are able to run scripts, see 
svc.RunHostScript fullHosts := make([]*fleet.Host, 0, len(hostIDs)) // The execution results to be stored in the database executions := make([]fleet.BatchExecutionHost, 0, len(fullHosts)) // Check that all hosts exist before attempting to process them for _, hostID := range hostIDs { host, err := ds.Host(ctx, hostID) if err != nil { fullHosts = append(fullHosts, &fleet.Host{ ID: hostID, Platform: invalidHostIDPlatform, }) continue } fullHosts = append(fullHosts, host) } if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { for _, host := range fullHosts { // Host doesn't exist anymore if host.Platform == invalidHostIDPlatform { executions = append(executions, fleet.BatchExecutionHost{ HostID: host.ID, Error: &fleet.BatchExecuteInvalidHost, }) continue } // Non-orbit-enrolled host (iOS, android) noNodeKey := host.OrbitNodeKey == nil || *host.OrbitNodeKey == "" // Scripts disabled on host scriptsDisabled := host.ScriptsEnabled != nil && !*host.ScriptsEnabled if noNodeKey || scriptsDisabled { executions = append(executions, fleet.BatchExecutionHost{ HostID: host.ID, Error: &fleet.BatchExecuteIncompatibleFleetd, }) continue } if !fleet.ValidateScriptPlatform(script.Name, host.Platform) { executions = append(executions, fleet.BatchExecutionHost{ HostID: host.ID, Error: &fleet.BatchExecuteIncompatiblePlatform, }) continue } executionID, _, err := ds.insertNewHostScriptExecution(ctx, tx, &fleet.HostScriptRequestPayload{ HostID: host.ID, UserID: userID, ScriptID: &script.ID, ScriptContentID: script.ScriptContentID, }, false) if err != nil { return ctxerr.Wrap(ctx, err, "queueing script for bulk execution") } executions = append(executions, fleet.BatchExecutionHost{ HostID: host.ID, ExecutionID: &executionID, }) } _, err := tx.ExecContext( ctx, `INSERT INTO batch_activities (execution_id, script_id, status, activity_type, num_targeted, started_at) VALUES (?, ?, ?, ?, ?, NOW()) ON DUPLICATE KEY UPDATE status = VALUES(status), started_at = VALUES(started_at)`, 
batchExecID, script.ID, fleet.ScheduledBatchExecutionStarted, fleet.BatchExecutionActivityScript, len(hostIDs), ) if err != nil { return ctxerr.Wrap(ctx, err, "failed to insert new batch execution") } args := make([]map[string]any, 0, len(executions)) for _, execHost := range executions { args = append(args, map[string]any{ "batch_id": batchExecID, "host_id": execHost.HostID, "host_execution_id": execHost.ExecutionID, "error": execHost.Error, }) } insertStmt := ` INSERT INTO batch_activity_host_results ( batch_execution_id, host_id, host_execution_id, error ) VALUES ( :batch_id, :host_id, :host_execution_id, :error ) ON DUPLICATE KEY UPDATE host_execution_id = VALUES(host_execution_id), error = VALUES(error)` if _, err := sqlx.NamedExecContext(ctx, tx, insertStmt, args); err != nil { return ctxerr.Wrap(ctx, err, "associating script executions with batch job") } return nil }); err != nil { return fmt.Errorf("creating bulk execution order: %w", err) } return nil } func (ds *Datastore) BatchExecuteScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint) (string, error) { batchExecID := uuid.New().String() script, err := ds.Script(ctx, scriptID) if err != nil { return "", fleet.NewInvalidArgumentError("script_id", err.Error()) } for _, hostID := range hostIDs { host, err := ds.HostLite(ctx, hostID) if err != nil { return "", fmt.Errorf("unable to load host information for %d: %w", hostID, err) } if !teamIDEq(host.TeamID, script.TeamID) { return "", ctxerr.Errorf(ctx, "all hosts must be on the same fleet as the script") } } if err := ds.batchExecuteScript(ctx, userID, scriptID, hostIDs, batchExecID); err != nil { return "", ctxerr.Wrap(ctx, err, "immediate batch execution") } return batchExecID, nil } func (ds *Datastore) BatchScheduleScript(ctx context.Context, userID *uint, scriptID uint, hostIDs []uint, notBefore time.Time) (string, error) { batchExecID := uuid.New().String() const batchActivitiesStmt = `INSERT INTO batch_activities (execution_id, 
job_id, script_id, user_id, status, activity_type, num_targeted) VALUES (?, ?, ?, ?, ?, ?, ?)` const batchHostsStmt = `INSERT INTO batch_activity_host_results (batch_execution_id, host_id) VALUES (:exec_id, :host_id)` argBytes, err := json.Marshal(fleet.BatchActivityScriptJobArgs{ ExecutionID: batchExecID, }) if err != nil { return "", ctxerr.Wrap(ctx, err, "encooding job args") } if err := ds.withTx(ctx, func(tx sqlx.ExtContext) error { job, err := ds.NewJob(ctx, &fleet.Job{ Name: fleet.BatchActivityScriptsJobName, Args: (*json.RawMessage)(&argBytes), State: fleet.JobStateQueued, NotBefore: notBefore.UTC(), }) if err != nil { return ctxerr.Wrap(ctx, err, "creating new job") } _, err = tx.ExecContext( ctx, batchActivitiesStmt, batchExecID, job.ID, scriptID, userID, fleet.ScheduledBatchExecutionScheduled, fleet.BatchExecutionActivityScript, len(hostIDs), ) if err != nil { return ctxerr.Wrap(ctx, err, "inserting new batch activity") } args := make([]map[string]any, 0, len(hostIDs)) for _, hostID := range hostIDs { args = append(args, map[string]any{ "exec_id": batchExecID, "host_id": hostID, }) } if _, err := sqlx.NamedExecContext(ctx, tx, batchHostsStmt, args); err != nil { return ctxerr.Wrap(ctx, err, "inserting batch host results") } return nil }); err != nil { return "", ctxerr.Wrap(ctx, err, "creating scheduled script execution") } return batchExecID, nil } func (ds *Datastore) CancelBatchScript(ctx context.Context, executionID string) error { stmt := ` SELECT bahr.host_execution_id, bahr.host_id FROM batch_activity_host_results bahr LEFT JOIN host_script_results hsr ON bahr.host_execution_id = hsr.execution_id -- I think? WHERE bahr.batch_execution_id = ? 
AND hsr.canceled = 0 AND hsr.exit_code IS NULL AND bahr.error IS NULL` stmtSetCanceled := ` UPDATE batch_activities ba SET finished_at = NOW(), status = 'finished', canceled = 1, num_canceled = (SELECT COUNT(*) FROM batch_activity_host_results WHERE batch_execution_id = ba.execution_id) WHERE ba.execution_id = ?` stmtCanceled := ` UPDATE batch_activities SET canceled = 1 WHERE execution_id = ?` activity, err := ds.GetBatchActivity(ctx, executionID) if err != nil { return ctxerr.Wrap(ctx, err, "getting batch activity") } if activity.Status == fleet.ScheduledBatchExecutionFinished { return nil } if err := ds.withRetryTxx(ctx, func(tx sqlx.ExtContext) error { // If job worker exists, mark it as complete to stop it from running if jobID := activity.JobID; jobID != nil { job, err := ds.GetJob(ctx, *jobID) if err != nil { return ctxerr.Wrap(ctx, err, "failed to find job associated with batch activity") } job.State = fleet.JobStateSuccess if _, err := ds.updateJob(ctx, tx, *jobID, job); err != nil { return ctxerr.Wrap(ctx, err, "updating batch activity job") } } if activity.Status == fleet.ScheduledBatchExecutionStarted { // If the batch activity has started, we need to cancel anything in progress or queued toCancel := []struct { HostExecutionID string `db:"host_execution_id"` HostID uint `db:"host_id"` }{} if err := sqlx.SelectContext(ctx, tx, &toCancel, stmt, executionID); err != nil { return ctxerr.Wrap(ctx, err, "selecting hosts to cancel") } for _, host := range toCancel { if _, err := ds.cancelHostUpcomingActivity(ctx, tx, host.HostID, host.HostExecutionID, true); err != nil { return ctxerr.Wrap(ctx, err, "canceling upcoming activity") } } if _, err := tx.ExecContext(ctx, stmtCanceled, executionID); err != nil { return ctxerr.Wrap(ctx, err, "setting canceled column") } if err := ds.markActivitiesAsCompleted(ctx, tx); err != nil { return ctxerr.Wrap(ctx, err, "marking job as complete and summarizing counts") } } else { // The batch activity is scheduled, but not 
started if _, err := tx.ExecContext(ctx, stmtSetCanceled, executionID); err != nil { return ctxerr.Wrap(ctx, err, "setting canceled host count") } } return nil }); err != nil { return ctxerr.Wrap(ctx, err, "cancel batch script db transaction") } return nil } func (ds *Datastore) GetBatchActivity(ctx context.Context, executionID string) (*fleet.BatchActivity, error) { const stmt = ` SELECT ba.id, ba.script_id, s.name as script_name, ba.execution_id, ba.user_id, ba.job_id, ba.status, ba.activity_type, ba.num_targeted, ba.num_pending, ba.num_ran, ba.num_errored, ba.num_incompatible, ba.num_canceled, ba.created_at, ba.updated_at, ba.started_at, ba.finished_at, ba.canceled FROM batch_activities ba LEFT JOIN scripts s ON s.id = ba.script_id WHERE execution_id = ?` batchActivity := &fleet.BatchActivity{} if err := sqlx.GetContext(ctx, ds.reader(ctx), batchActivity, stmt, executionID); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting batch activity") } return batchActivity, nil } func (ds *Datastore) GetBatchActivityHostResults(ctx context.Context, executionID string) ([]*fleet.BatchActivityHostResult, error) { const stmt = ` SELECT id, batch_execution_id, host_id, host_execution_id, error FROM batch_activity_host_results WHERE batch_execution_id = ?` results := []*fleet.BatchActivityHostResult{} if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, stmt, executionID); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting batch activity host results") } return results, nil } func (ds *Datastore) RunScheduledBatchActivity(ctx context.Context, executionID string) error { batchActivity, err := ds.GetBatchActivity(ctx, executionID) if err != nil { return ctxerr.Wrap(ctx, err, "getting batch activity") } if batchActivity.Status != fleet.ScheduledBatchExecutionScheduled { return ctxerr.New(ctx, "batch job has already been started") } if batchActivity.Canceled { return ctxerr.New(ctx, "batch job was canceled") } if batchActivity.ScriptID == nil { return 
ctxerr.New(ctx, "no script ID present in batch activity") } script, err := ds.Script(ctx, *batchActivity.ScriptID) if err != nil { return ctxerr.Wrap(ctx, err, "could not get script") } results, err := ds.GetBatchActivityHostResults(ctx, executionID) if err != nil { return ctxerr.Wrap(ctx, err, "getting batch activity host results") } hostIDs := []uint{} for _, result := range results { hostIDs = append(hostIDs, result.HostID) } if err := ds.batchExecuteScript(ctx, batchActivity.UserID, script.ID, hostIDs, batchActivity.BatchExecutionID); err != nil { return ctxerr.Wrap(ctx, err, "scheduled batch script execution") } return nil } // Deprecated; will be removed in favor of ListBatchScriptExecutions when the batch script details page is ready. func (ds *Datastore) BatchExecuteSummary(ctx context.Context, executionID string) (*fleet.BatchActivity, error) { stmtExecutions := ` SELECT COUNT(*) as num_targeted, COUNT(bsehr.error) as num_did_not_run, COUNT(CASE WHEN hsr.exit_code = 0 THEN 1 END) as num_succeeded, COUNT(CASE WHEN hsr.exit_code <> 0 THEN 1 END) as num_failed, COUNT(CASE WHEN hsr.canceled = 1 AND hsr.exit_code IS NULL THEN 1 END) as num_cancelled FROM batch_activity_host_results bsehr LEFT JOIN host_script_results hsr ON bsehr.host_execution_id = hsr.execution_id WHERE bsehr.batch_execution_id = ?` stmtScriptDetails := ` SELECT script_id, s.name as script_name, s.team_id as team_id, bse.created_at as created_at FROM batch_activities bse JOIN scripts s ON bse.script_id = s.id WHERE bse.execution_id = ?` var summary fleet.BatchActivity var temp_summary struct { NumTargeted uint `db:"num_targeted"` NumDidNotRun uint `db:"num_did_not_run"` NumSucceeded uint `db:"num_succeeded"` NumFailed uint `db:"num_failed"` NumCancelled uint `db:"num_cancelled"` } // Fill out the execution details if err := sqlx.GetContext(ctx, ds.reader(ctx), &temp_summary, stmtExecutions, executionID); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting execution information for bulk 
execution summary") } summary.NumTargeted = &temp_summary.NumTargeted // NumRan is the number of hosts that actually ran the script successfully. summary.NumRan = &temp_summary.NumSucceeded // NumErrored is the number of hosts that errored out, which includes // both failed and did not run. summary.NumErrored = ptr.Uint(temp_summary.NumFailed + temp_summary.NumDidNotRun) // NumFailed is the number of hosts that were canceled before execution. summary.NumCanceled = &temp_summary.NumCancelled // NumPending is the number of hosts that are pending execution. summary.NumPending = ptr.Uint(temp_summary.NumTargeted - (temp_summary.NumSucceeded + temp_summary.NumFailed + temp_summary.NumDidNotRun + temp_summary.NumCancelled)) // Fill out the script details if err := sqlx.GetContext(ctx, ds.reader(ctx), &summary, stmtScriptDetails, executionID); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting script information for bulk execution summary") } if summary.TeamID == nil { summary.TeamID = ptr.Uint(0) } return &summary, nil } func (ds *Datastore) ListBatchScriptExecutions(ctx context.Context, filter fleet.BatchExecutionStatusFilter) ([]fleet.BatchActivity, error) { stmtExecutions := ` SELECT * FROM ( -- If batch is finished, get the cached host result counts SELECT COALESCE(ba.num_targeted, 0) AS num_targeted, COALESCE(ba.num_incompatible, 0) AS num_incompatible, COALESCE(ba.num_ran, 0) AS num_ran, COALESCE(ba.num_errored, 0) AS num_errored, COALESCE(ba.num_canceled, 0) AS num_canceled, COALESCE(ba.num_pending, 0) AS num_pending, ba.execution_id, ba.script_id, ba.status, ba.canceled, ba.finished_at, ba.started_at, s.name AS script_name, s.global_or_team_id AS team_id, ba.created_at AS created_at, j.not_before AS not_before, ba.id AS id FROM batch_activities ba JOIN scripts s ON ba.script_id = s.id LEFT JOIN jobs j ON j.id = ba.job_id WHERE ( %s ) AND ba.status = 'finished' UNION ALL -- If batch is not finished, calculate the host result counts live. 
SELECT COUNT(bahr.host_id) AS num_targeted, COUNT(bahr.error) AS num_incompatible, COUNT(IF(hsr.exit_code = 0, 1, NULL)) AS num_ran, COUNT(IF(hsr.exit_code <> 0, 1, NULL)) AS num_errored, COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba.canceled = 1), 1, NULL)) AS num_cancelled, ( COUNT(bahr.host_id) - COUNT(bahr.error) - COUNT(IF(hsr.exit_code = 0, 1, NULL)) - COUNT(IF(hsr.exit_code <> 0, 1, NULL)) - COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba.canceled = 1), 1, NULL)) ) AS num_pending, ba.execution_id, ba.script_id, ba.status, ba.canceled, ba.finished_at, ba.started_at, s.name AS script_name, s.global_or_team_id AS team_id, ba.created_at AS created_at, j.not_before AS not_before, ba.id AS id FROM batch_activities ba LEFT JOIN batch_activity_host_results bahr ON ba.execution_id = bahr.batch_execution_id LEFT JOIN host_script_results hsr ON bahr.host_execution_id = hsr.execution_id JOIN scripts s ON ba.script_id = s.id LEFT JOIN jobs j ON j.id = ba.job_id WHERE ( %s ) AND ba.status <> 'finished' GROUP BY ba.id ) AS u ORDER BY %s LIMIT %d OFFSET %d ` limit := 10 offset := 0 args := []any{} orderBy := []string{"u.created_at DESC", "u.id DESC"} whereClauses := make([]string, 0, 2) // If an execution ID is provided, use it to filter the results. if filter.ExecutionID != nil && *filter.ExecutionID != "" { whereClauses = append(whereClauses, "ba.execution_id = ?") args = append(args, *filter.ExecutionID) } else { // Otherwise filter by status and/or team ID. if filter.Status != nil && *filter.Status != "" { whereClauses = append(whereClauses, "ba.status = ?") args = append(args, *filter.Status) switch *filter.Status { case string(fleet.ScheduledBatchExecutionScheduled): orderBy = append([]string{"u.not_before ASC"}, orderBy...) case string(fleet.ScheduledBatchExecutionStarted): orderBy = append([]string{"u.started_at DESC"}, orderBy...) 
case string(fleet.ScheduledBatchExecutionFinished): orderBy = append([]string{"u.finished_at DESC"}, orderBy...) default: // no additional ordering } } if filter.TeamID != nil { whereClauses = append(whereClauses, "s.global_or_team_id = ?") args = append(args, *filter.TeamID) } } // Double up the args to use them in both WHERE clauses. args = append(args, args...) // Use pagination parameters if provided. if filter.Limit != nil { limit = int(*filter.Limit) //nolint:gosec // dismiss G115 } if filter.Offset != nil { offset = int(*filter.Offset) //nolint:gosec // dismiss G115 } where := strings.Join(whereClauses, " AND ") stmtExecutions = fmt.Sprintf(stmtExecutions, where, where, strings.Join(orderBy, ", "), limit, offset) var summary []fleet.BatchActivity if err := sqlx.SelectContext(ctx, ds.reader(ctx), &summary, stmtExecutions, args...); err != nil { return nil, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary") } return summary, nil } func (ds *Datastore) CountBatchScriptExecutions(ctx context.Context, filter fleet.BatchExecutionStatusFilter) (int64, error) { stmtExecutions := ` SELECT COUNT(*) FROM batch_activities ba JOIN scripts s ON ba.script_id = s.id WHERE %s ` args := []any{} whereClauses := make([]string, 0, 2) if filter.Status != nil && *filter.Status != "" { whereClauses = append(whereClauses, "ba.status = ?") args = append(args, *filter.Status) } if filter.TeamID != nil { whereClauses = append(whereClauses, "s.global_or_team_id = ?") args = append(args, *filter.TeamID) } where := strings.Join(whereClauses, " AND ") stmtExecutions = fmt.Sprintf(stmtExecutions, where) var count int64 if err := sqlx.GetContext(ctx, ds.reader(ctx), &count, stmtExecutions, args...); err != nil { return 0, ctxerr.Wrap(ctx, err, "selecting execution information for bulk execution summary") } return count, nil } func (ds *Datastore) markActivitiesAsCompleted(ctx context.Context, tx sqlx.ExtContext) error { const stmt = ` UPDATE batch_activities 
AS ba JOIN ( SELECT ba2.id AS batch_id, COUNT(bahr.host_id) AS num_targeted, COUNT(bahr.error) AS num_incompatible, COUNT(IF(hsr.exit_code = 0, 1, NULL)) AS num_ran, COUNT(IF(hsr.exit_code <> 0, 1, NULL)) AS num_errored, COUNT(IF((hsr.canceled = 1 AND hsr.exit_code IS NULL) OR (hsr.host_id IS NULL AND bahr.error is NULL AND ba2.canceled = 1), 1, NULL)) AS num_canceled FROM batch_activities AS ba2 LEFT JOIN batch_activity_host_results AS bahr ON ba2.execution_id = bahr.batch_execution_id LEFT JOIN host_script_results AS hsr ON bahr.host_execution_id = hsr.execution_id WHERE ba2.status = 'started' GROUP BY ba2.id HAVING (num_incompatible + num_ran + num_errored + num_canceled) >= num_targeted ) AS agg ON agg.batch_id = ba.id SET ba.status = 'finished', ba.finished_at = NOW(), ba.num_targeted = agg.num_targeted, ba.num_incompatible = agg.num_incompatible, ba.num_ran = agg.num_ran, ba.num_errored = agg.num_errored, ba.num_canceled = agg.num_canceled, ba.num_pending = 0 WHERE ba.status = 'started'; ` // TODO -- use `RETURNING` to return the IDs of the updated activities? _, err := tx.ExecContext(ctx, stmt) if err != nil { return ctxerr.Wrap(ctx, err, "marking activities as completed") } return nil } func (ds *Datastore) MarkActivitiesAsCompleted(ctx context.Context) error { return ds.markActivitiesAsCompleted(ctx, ds.writer(ctx)) }